from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
- from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

# adapted from
# https://github.com/huggingface/transformers/blob/
# be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
# /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
+ class Qwen2_5_VLVisionConfig(PretrainedConfig):
+     model_type = "qwen2_5_vl"
+
+     def __init__(
+         self,
+         depth=32,
+         hidden_size=3584,
+         hidden_act="silu",
+         intermediate_size=3420,
+         num_heads=16,
+         in_channels=3,
+         patch_size=14,
+         spatial_merge_size=2,
+         temporal_patch_size=2,
+         tokens_per_second=4,
+         window_size=112,
+         out_hidden_size=3584,
+         fullatt_block_indexes=[7, 15, 23, 31],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.depth = depth
+         self.hidden_size = hidden_size
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.num_heads = num_heads
+         self.in_channels = in_channels
+         self.patch_size = patch_size
+         self.spatial_merge_size = spatial_merge_size
+         self.temporal_patch_size = temporal_patch_size
+         self.tokens_per_second = tokens_per_second
+         self.window_size = window_size
+         self.fullatt_block_indexes = fullatt_block_indexes
+         self.out_hidden_size = out_hidden_size
+
+
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
@@ -68,6 +104,27 @@ def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb_vision(
+     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     orig_q_dtype = q.dtype
+     orig_k_dtype = k.dtype
+     q, k = q.float(), k.float()
+     cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     q_embed = q_embed.to(orig_q_dtype)
+     k_embed = k_embed.to(orig_k_dtype)
+     return q_embed, k_embed
+
+
class Qwen2_5_VLVisionFlashAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
@@ -76,27 +133,26 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

-     def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-         t_ = t.float()
-         cos = freqs.cos()
-         sin = freqs.sin()
-         output = apply_rotary_emb(t_, cos, sin).type_as(t)
-         return output
-
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
-         max_seqlen: int = 0,
        rotary_pos_emb: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-         q = self.apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-         k = self.apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+         if position_embeddings is None:
+             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+             cos = emb.cos()
+             sin = emb.sin()
+         else:
+             cos, sin = position_embeddings
+         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

+         cu_seqlens = cu_seqlens.to(q.device, torch.int32)
+         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
-
        flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
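
For reference, here is a minimal standalone sketch of what the cu_seqlens / max_seqlen pair consumed by flash_attention_fwd encodes; the toy sizes are invented for illustration, and plain scaled_dot_product_attention stands in for the Triton varlen kernel:

    import torch
    import torch.nn.functional as F

    # Three images packed along one token axis: 4, 6 and 2 patches.
    seq_lens = torch.tensor([4, 6, 2], dtype=torch.int32)
    cu_seqlens = F.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))  # [0, 4, 10, 12]
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())         # 6, as recomputed in forward()

    num_heads, head_dim = 2, 8
    q = k = v = torch.randn(int(cu_seqlens[-1]), num_heads, head_dim)

    # Emulate the varlen kernel: attend within each [start, end) segment only.
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        qi, ki, vi = (t[start:end].transpose(0, 1) for t in (q, k, v))  # (heads, tokens, dim)
        out[start:end] = F.scaled_dot_product_attention(qi, ki, vi).transpose(0, 1)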
@@ -127,14 +183,12 @@ def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
-         max_seqlen: int = 0,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
-             max_seqlen=max_seqlen,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
        )
@@ -178,7 +232,6 @@ def __init__(
        **kwargs,
    ):
        super().__init__()
-         self.data_type = kvargs.get("data_type", "bfloat16")

        self.depth = depth
        self.hidden_size = hidden_size
@@ -204,7 +257,7 @@ def __init__(
        )

        head_dim = self.hidden_size // self.num_heads
-         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", non_blocking=True)
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", dtype=self.get_dtype(), non_blocking=True)

        self.blocks = nn.ModuleList(
            [
@@ -226,62 +279,41 @@ def __init__(

        self.gradient_checkpointing = False

-         self._init_datatype()
-
-     def _init_datatype(self):
-         if isinstance(self.data_type, torch.dtype):
-             return
-         if self.data_type in ["fp16", "float16"]:
-             self.data_type = torch.float16
-         elif self.data_type in ["bf16", "bfloat16"]:
-             self.data_type = torch.bfloat16
-         elif self.data_type in ["fp32", "float32"]:
-             self.data_type = torch.float32
-         else:
-             raise ValueError(f"Unsupport datatype {self.data_type}!")
-         return
+         self.device = self.get_device()
+         self.dtype = self.get_dtype()

-     def load_model(self, weight_dir):
+     def get_dtype(self) -> torch.dtype:
+         return self.blocks[0].mlp.down_proj.weight.dtype

-         processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
-         with open(processor_config_path, "r") as f:
-             processor_config_dict = json.load(f)
-         self.processor = Qwen2VLImageProcessor(**processor_config_dict)
-
-         bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
-         if bin_weight_files:
-             weight_dict = {}
-             for file_ in bin_weight_files:
-                 f = torch.load(os.path.join(weight_dir, file_), "cpu")
-                 for k, v in f.items():
-                     if "visual" in k:
-                         weight_dict[k[len("visual.") :]] = v
-
-         else:
-             hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
-             weight_dict = {}
-             for file_ in hf_weight_files:
-                 f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
-                 for k in f.keys():
-                     if "visual" in k:
-                         weight_dict[k[len("visual.") :]] = f.get_tensor(k)
-
-         self.load_state_dict(weight_dict)
+     def get_device(self) -> torch.device:
+         return self.blocks[0].mlp.down_proj.weight.device

    def rot_pos_emb(self, grid_thw):
        pos_ids = []
-         s = self.spatial_merge_size
-         for _, h, w in grid_thw:
-             pos_shape = (h // s, s, w // s, s)
+         for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-             hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
-             wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+             hpos_ids = hpos_ids.reshape(
+                 h // self.spatial_merge_size,
+                 self.spatial_merge_size,
+                 w // self.spatial_merge_size,
+                 self.spatial_merge_size,
+             )
+             hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+             hpos_ids = hpos_ids.flatten()

-             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
+             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+             wpos_ids = wpos_ids.reshape(
+                 h // self.spatial_merge_size,
+                 self.spatial_merge_size,
+                 w // self.spatial_merge_size,
+                 self.spatial_merge_size,
+             )
+             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+             wpos_ids = wpos_ids.flatten()
+             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
-         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

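As a standalone illustration of the merge-unit ordering that rot_pos_emb produces, here is a toy version assuming one frame of a 4x4 patch grid and spatial_merge_size = 2 (all names below are local to the example):

    import torch

    spatial_merge_size = 2
    t, h, w = 1, 4, 4  # one frame of a 4x4 patch grid (toy sizes)

    def merge_order(x):
        # Regroup an (h, w) index map into 2x2 merge windows so that the four
        # patches that get merged downstream are adjacent in the flat order.
        return (
            x.reshape(h // spatial_merge_size, spatial_merge_size,
                      w // spatial_merge_size, spatial_merge_size)
            .permute(0, 2, 1, 3)
            .flatten()
        )

    hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
    wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
    pos_ids = torch.stack([merge_order(hpos), merge_order(wpos)], dim=-1).repeat(t, 1)
    print(pos_ids[:4].tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]] -- one 2x2 merge window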
@@ -328,7 +360,7 @@ def get_window_index(self, grid_thw):

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        hidden_states = self.patch_embed(hidden_states)
-         rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
@@ -344,14 +376,20 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to("cuda", dtype=self.get_dtype(), non_blocking=True)
+         position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
+             # Select dtype based on the following factors:
+             #  - FA2 requires that cu_seqlens_q must have dtype int32
+             #  - torch.onnx.export requires that cu_seqlens_q must have same
+             #    dtype as grid_thw
+             # See https://github.com/huggingface/transformers/pull/34852
+             # for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
@@ -361,8 +399,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
-                 max_seqlen=max_seqlen,
-                 position_embeddings=emb,
+                 position_embeddings=position_embeddings,
            )

        hidden_states = self.merger(hidden_states)
@@ -371,6 +408,33 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.

        return hidden_states

+     def load_model(self, weight_dir):
+
+         processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+         with open(processor_config_path, "r") as f:
+             processor_config_dict = json.load(f)
+         self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+         bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+         if bin_weight_files:
+             weight_dict = {}
+             for file_ in bin_weight_files:
+                 f = torch.load(os.path.join(weight_dir, file_), "cpu")
+                 for k, v in f.items():
+                     if "visual" in k:
+                         weight_dict[k[len("visual.") :]] = v
+
+         else:
+             hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+             weight_dict = {}
+             for file_ in hf_weight_files:
+                 f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+                 for k in f.keys():
+                     if "visual" in k:
+                         weight_dict[k[len("visual.") :]] = f.get_tensor(k)
+
+         self.load_state_dict(weight_dict)
+
    def encode(self, images: List[ImageItem]):
        img_tensors = []
        valid_ids = []
@@ -402,7 +466,7 @@ def encode(self, images: List[ImageItem]):
        imgs = torch.cat(img_tensors, dim=0)
        grid_thw = torch.cat(img_grids, dim=0)

-         pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+         pixel_values = imgs.to("cuda", dtype=self.get_dtype(), non_blocking=True)
        image_grid_thw = grid_thw.to("cuda", non_blocking=True)

        all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
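
For completeness, a small self-contained check of the rotate_half-style rotary application introduced by this diff; the tensor sizes are toy values, and the cos/sin construction mirrors emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) in forward():

    import torch

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    seq_len, num_heads, head_dim = 5, 2, 8  # toy sizes
    q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
    k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)

    # Per-token angles for half the head dim, duplicated to cover the full dim.
    freqs = torch.rand(seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()

    # Same math as apply_rotary_pos_emb_vision: broadcast over heads, compute in
    # float32, then cast back to the original dtype.
    cos_, sin_ = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q.float() * cos_ + rotate_half(q.float()) * sin_).to(q.dtype)
    k_embed = (k.float() * cos_ + rotate_half(k.float()) * sin_).to(k.dtype)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape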