 import os
-import re
 import json
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from typing import Any, Dict, List, Optional, Tuple, Union
-from torchvision import transforms as T
-from torchvision.transforms.functional import InterpolationMode
-from transformers import AutoModel, AutoTokenizer
+from typing import List, Optional
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from io import BytesIO
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_utils import PreTrainedModel
 import torch.nn as nn
-from torch.nn import LayerNorm
 from transformers.activations import ACT2FN
-import math
 from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
-from transformers import AutoProcessor
 from safetensors import safe_open
-from transformers.utils import TensorType
-from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.server.multimodal_params import ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
-
-# adapted from
-# https://github.com/huggingface/transformers/blob/
-# be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
-# /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
-class Qwen2_5_VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_5_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        hidden_size=3584,
-        hidden_act="silu",
-        intermediate_size=3420,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        tokens_per_second=4,
-        window_size=112,
-        out_hidden_size=3584,
-        fullatt_block_indexes=[7, 15, 23, 31],
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-        self.tokens_per_second = tokens_per_second
-        self.window_size = window_size
-        self.fullatt_block_indexes = fullatt_block_indexes
-        self.out_hidden_size = out_hidden_size
+from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton


 class Qwen2RMSNorm(nn.Module):
@@ -76,11 +28,7 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return rms_norm(hidden_states, self.weight, eps=self.variance_epsilon)

     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@@ -104,27 +52,6 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
@@ -137,21 +64,16 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        q = apply_rotary_pos_emb_triton(q.unsqueeze(0), rotary_pos_emb.cos(), rotary_pos_emb.sin())
+        k = apply_rotary_pos_emb_triton(k.unsqueeze(0), rotary_pos_emb.cos(), rotary_pos_emb.sin())
+        q = q.squeeze(0)
+        k = k.squeeze(0)

-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
@@ -183,14 +105,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
-            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -232,6 +154,8 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+        self.weight_dir = kvargs["weight_dir"]
+        self.data_type = kvargs.get("data_type", "bfloat16")

         self.depth = depth
         self.hidden_size = hidden_size
@@ -257,7 +181,7 @@ def __init__(
         )

         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", dtype=self.get_dtype(), non_blocking=True)
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

         self.blocks = nn.ModuleList(
             [
@@ -279,41 +203,42 @@ def __init__(

         self.gradient_checkpointing = False

-        self.device = self.get_device()
-        self.dtype = self.get_dtype()
-
-    def get_dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.down_proj.weight.dtype
+        processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json")
+        with open(processor_config_path, "r") as f:
+            processor_config_dict = json.load(f)
+        self.processor = Qwen2VLImageProcessor(**processor_config_dict)

-    def get_device(self) -> torch.device:
-        return self.blocks[0].mlp.down_proj.weight.device
+        self._init_datatype()
+        self.load_model(kvargs["weight_dir"])
+        self.cuda()
+
+    def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
+        if self.data_type in ["fp16", "float16"]:
+            self.data_type = torch.float16
+        elif self.data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+            raise ValueError(f"Unsupport datatype {self.data_type}!")
+        return

     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        s = self.spatial_merge_size
+        for _, h, w in grid_thw:
+            pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

@@ -360,14 +285,22 @@ def get_window_index(self, grid_thw):

     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0, dtype=torch.int32
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=hidden_states.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item()

         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -376,30 +309,20 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to("cuda", dtype=self.get_dtype(), non_blocking=True)
-        position_embeddings = (emb.cos(), emb.sin())
-
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0,
-            # Select dtype based on the following factors:
-            #  - FA2 requires that cu_seqlens_q must have dtype int32
-            #  - torch.onnx.export requires that cu_seqlens_q must have same
-            #    dtype as grid_thw
-            # See https://github.com/huggingface/transformers/pull/34852
-            # for more information
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
-        )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen
             else:
                 cu_seqlens_now = cu_window_seqlens
+                max_seqlen_now = max_window_seqlen
+
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
-                position_embeddings=position_embeddings,
+                max_seqlen=max_seqlen_now,
+                rotary_pos_emb=rotary_pos_emb,
             )

         hidden_states = self.merger(hidden_states)
@@ -408,12 +331,23 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.

         return hidden_states

-    def load_model(self, weight_dir):
+    def load_image(self, img: List[ImageItem]):
+        pixel_values = None
+        if isinstance(img, ImageItem):
+            image_data = read_shm(get_shm_name_data(img.uuid))
+            image_data = Image.open(BytesIO(image_data))
+            image_data = resize_image(image_data)
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+        elif isinstance(img, dict):
+            image_data = read_shm(get_shm_name_data(img["uuid"]))
+            image_data = Image.open(BytesIO(image_data))
+            image_data = resize_image(image_data)
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+        else:
+            raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+        return pixel_values.to(dtype=self.data_type), image_grid_thw

-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
-        with open(processor_config_path, "r") as f:
-            processor_config_dict = json.load(f)
-        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+    def load_model(self, weight_dir):

         bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
         if bin_weight_files:
@@ -466,7 +400,7 @@ def encode(self, images: List[ImageItem]):
         imgs = torch.cat(img_tensors, dim=0)
         grid_thw = torch.cat(img_grids, dim=0)

-        pixel_values = imgs.to("cuda", dtype=self.get_dtype(), non_blocking=True)
+        pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
         image_grid_thw = grid_thw.to("cuda", non_blocking=True)

         all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
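
Note on the sequence-length bookkeeping introduced above: forward() now computes cu_seqlens and max_seqlen once, alongside max_window_seqlen for the windowed blocks, and passes them down to flash_attention_fwd, instead of each attention layer re-deriving max_seqlen from cu_seqlens. A minimal standalone sketch of that computation follows; the grid_thw values are made up for illustration and are not part of the patch.

import torch
import torch.nn.functional as F

# Two images, each with t=1 temporal patch and (h, w) patch grids of (4, 6) and (8, 8).
grid_thw = torch.tensor([[1, 4, 6], [1, 8, 8]])

# Per-frame token counts (h * w, repeated t times), cumulated and prefixed with 0, as in forward().
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)  # tensor([ 0, 24, 88], dtype=torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 64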