1+ from typing import Optional , List
12import torch
23import numpy as np
34from lightllm .models .llama .infer_struct import LlamaInferStateInfo
45from lightllm .common .basemodel .infer_struct import InferStateInfo
6+ from lightllm .models .qwen2_vl .triton_kernel .get_mrope_position_ids import get_mrope_position_triton
7+ from lightllm .models .llama .flashattention_infer_struct import FlashAttentionStateInfo
8+ from lightllm .utils .envs_utils import get_env_start_args
59
610
711class Qwen2VLInferStateInfo (LlamaInferStateInfo ):
12+ init_flash_attention_state_func = FlashAttentionStateInfo ._init_flash_attention_state
13+
814 def __init__ (self ):
915 super ().__init__ ()
1016 self .position_cos = None
@@ -13,17 +19,64 @@ def __init__(self):
    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        """Set up multimodal rotary position embedding (mrope) state for this pass.

        Builds 3-row (one row per mrope axis) position ids and gathers the
        matching cos/sin rotary tables from the model's caches; optionally
        initializes FlashAttention-3 state when enabled by the start args.

        Args:
            model: loaded model; read for ``config``, ``_cos_cached`` and
                ``_sin_cached``.
            input_ids: token ids of the current batch, forwarded to the base
                init and to the flash-attention state initializer.
        """
        rope_scaling = model.config.get("rope_scaling", {})
        # Newer configs use "rope_type", older ones "type"; accept either.
        self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
        InferStateInfo.init_some_extra_state(self, model, input_ids)
        if self.is_prefill:
            # Prefill: compute full 3xL mrope position ids (text + image spans).
            self.position_ids = self.get_mrope_position(self.multimodal_params)
        else:
            # Decode: shift each request's position by its accumulated per-image
            # delta. grid_thwd[3] is summed per request — presumably the position
            # offset each image contributes; TODO confirm against preprocessing.
            b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
            for batch_idx, p in enumerate(self.multimodal_params):
                position_delta = 0
                for image in p["images"]:
                    position_delta += image["grid_thwd"][3]
                b_position_delta[batch_idx] = position_delta
            position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
            # Decode reuses the same shifted position on all three mrope axes.
            self.position_ids = position_ids.unsqueeze(0).expand(3, -1)

        # `expand` yields a non-contiguous view; force a real layout before use.
        self.position_ids = self.position_ids.contiguous()
        self.position_cos = model._cos_cached[self.position_ids]  # (3, L, D)
        self.position_sin = model._sin_cached[self.position_ids]  # (3, L, D)
        if get_env_start_args().enable_fa3:
            # FA3 path: expose the seq-len aliases it expects, then init its state.
            self.max_seq_len = self.max_kv_seq_len
            self.q_max_seq_len = self.max_q_seq_len
            self.init_flash_attention_state_func(model, input_ids)
        return
43+
44+ def get_mrope_position (self , multimodal_params : List [dict ]) -> torch .Tensor :
45+ if len (multimodal_params ) == 0 :
46+ return self .position_ids .unsqueeze (0 ).expand (3 , - 1 )
47+ b_image_start_idx = []
48+ b_image_nums = []
49+ b_image_start_num = []
50+ b_image_len = []
51+ image_start_num = 0
52+ b_image_thwd = []
53+ for _ , p in enumerate (multimodal_params ):
54+ images = p .get ("images" , [])
55+ for img in images :
56+ b_image_start_idx .append (img ["start_idx" ])
57+ b_image_len .append (img ["token_num" ])
58+ b_image_thwd .append (img ["grid_thwd" ])
59+ b_image_nums .append (len (images ))
60+ b_image_start_num .append (image_start_num )
61+ image_start_num += len (images )
62+ # 没有任何图片
63+ if image_start_num == 0 :
64+ return self .position_ids .unsqueeze (0 ).expand (3 , - 1 ).contiguous ()
65+ b_image_start_idx = torch .tensor (b_image_start_idx , device = "cpu" ).cuda (non_blocking = True )
66+ b_image_thwd = torch .tensor (b_image_thwd , device = "cpu" ).cuda (non_blocking = True ) # image_num x 4
67+ b_image_nums = torch .tensor (b_image_nums , device = "cpu" ).cuda (non_blocking = True )
68+ b_image_start_num = torch .tensor (b_image_start_num , device = "cpu" ).cuda (non_blocking = True )
69+ b_image_len = torch .tensor (b_image_len , device = self .position_ids .device )
70+ position_ids = self .position_ids .unsqueeze (0 ).expand (3 , - 1 ).contiguous ()
71+ get_mrope_position_triton (
72+ b_image_start_idx = b_image_start_idx ,
73+ b_image_thwd = b_image_thwd ,
74+ b_image_nums = b_image_nums ,
75+ b_image_start_num = b_image_start_num ,
76+ b_image_len = b_image_len ,
77+ position_ids = position_ids ,
78+ b_ready_cache_len = self .b_ready_cache_len ,
79+ b_q_seq_len = self .b_q_seq_len ,
80+ b_start_loc = self .b_start_loc ,
81+ )
82+ return position_ids
0 commit comments