@@ -7,7 +7,6 @@
 from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer

-
 class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(self, layer_num, network_config, mode=[]):
         super().__init__(layer_num, network_config, mode)
@@ -19,10 +18,10 @@ def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         seq_len, _ = q.shape
-        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
-        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
+        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
+        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
         new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section)
-        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
+        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1).contiguous()
         cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)

         return new_q, cache_kv
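Aside on why the added .contiguous() calls matter (not part of the patch, and it assumes mrope_triton expects row-major inputs, which the diff implies but does not state): transpose() in PyTorch returns a strided view over the same storage rather than copying data, so a kernel that computes addresses as if the tensor were laid out contiguously would read the wrong elements. A minimal sketch using only stock PyTorch:

import torch

q = torch.randn(1, 8, 4, 64)                # (batch, seq_len, heads, head_dim)
q_t = q.transpose(1, 2)                     # (batch, heads, seq_len, head_dim), a view
print(q_t.is_contiguous())                  # False: same storage, permuted strides
print(q_t.contiguous().is_contiguous())     # True: data copied into row-major order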