 from paddlenlp.transformers.linear_utils import Linear
 from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
 from paddlenlp.transformers.model_utils import PretrainedModel
+from paddlenlp.utils.tools import get_env_device

 from paddlemix.models.flash_attn_utils import (
     create_attention_module,
 from .bert_padding import index_first_axis, pad_input, unpad_input
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig

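+# fused_rotary_position_embedding is an optional fused kernel; when it cannot be imported,
+# the eager rotate_half path is used instead.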
+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
 logger = logging.get_logger(__name__)

 flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
@@ -68,6 +74,15 @@ def get_triangle_upper_mask(x, mask=None):


 def parallel_matmul(x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True):
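+    # On XPU, delegate to paddle_xpu's parallel matmul; raise if paddle_xpu is not installed.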
+    if get_env_device() == "xpu":
+        try:
+            from paddle_xpu.layers.nn.linear import parallel_matmul as xpu_parallel_matmul
+            return xpu_parallel_matmul(x, y, transpose_y=transpose_y, tensor_parallel_output=True)
+        except ImportError:
+            raise NotImplementedError(
+                "Implementation of parallel_matmul is not available on xpu. Please install paddle_xpu to use this feature"
+            )
+
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -407,7 +422,12 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) ->
     sin = freqs.sin()
     cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
     sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
-    output = tensor * cos + rotate_half(tensor) * sin
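+    # Use the fused rotary embedding kernel on XPU when available; otherwise apply the rotation explicitly.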
+    if get_env_device() == "xpu" and fused_rotary_position_embedding is not None:
+        output, _, _ = fused_rotary_position_embedding(
+            tensor, sin=sin, cos=cos, use_neox_rotary_style=False
+        )
+    else:
+        output = tensor * cos + rotate_half(tensor) * sin
     output = paddle.cast(output, orig_dtype)
     return output

@@ -463,6 +483,12 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N
             nn.GELU(),
             nn.Linear(self.hidden_size, dim),
         )
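+        # On XPU, rebuild the MLP with the Linear alias from paddlenlp.transformers.linear_utils,
+        # which resolves to an XPU-specific Linear when paddle_xpu is available.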
+        if get_env_device() == "xpu":
+            self.mlp = nn.Sequential(
+                Linear(self.hidden_size, self.hidden_size),
+                nn.GELU(),
+                Linear(self.hidden_size, dim),
+            )

     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
@@ -475,6 +501,9 @@ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
         self.fc1 = nn.Linear(dim, hidden_dim)
         self.act = ACT2FN[hidden_act]
         self.fc2 = nn.Linear(hidden_dim, dim)
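+        # On XPU, swap in the paddlenlp Linear alias (XPU-specific implementation when paddle_xpu is installed).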
+        if get_env_device() == "xpu":
+            self.fc1 = Linear(dim, hidden_dim)
+            self.fc2 = Linear(hidden_dim, dim)

     def forward(self, x) -> paddle.Tensor:
         return self.fc2(self.act(self.fc1(x)))
@@ -486,6 +515,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
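+        # On XPU, swap in the paddlenlp Linear alias for the qkv and proj projections.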
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added

     def forward(
@@ -525,6 +557,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
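+        # On XPU, swap in the paddlenlp Linear alias for the qkv and proj projections.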
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added

     def forward(
@@ -657,6 +692,15 @@ def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
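+        # On XPU, use paddle_xpu's fused RMSNorm kernel; raise if paddle_xpu is not installed.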
+        if get_env_device() == "xpu":
+            try:
+                import paddle_xpu_nn  # noqa: F821
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+            except ImportError:
+                raise NotImplementedError(
+                    "Implementation of fused_rms_norm is not available on xpu. Please install paddle_xpu to use this feature"
+                )
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
                 variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
@@ -1193,7 +1237,7 @@ class Qwen2VLPreTrainedModel(PretrainedModel):

     def _init_weights(self, layer):
         std = 0.2
-        if isinstance(layer, (nn.Linear, nn.Conv3D)):
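+        # Also cover the paddlenlp Linear alias so XPU linear layers receive the same initialization.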
+        if isinstance(layer, (nn.Linear, nn.Conv3D, Linear)):
             nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
             if layer.bias is not None:
                 nn.initializer.Constant(0.0)(layer.bias)