@@ -112,16 +112,22 @@ class VisionBuckets:
     This class is used to bucket image tokens
     '''
 
-    def __init__(self, is_batch_based):
-        self.is_batch_based = is_batch_based
+    def __init__(self, model):
+        self.is_batch_based = True
         envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")
         if envvar == 'None':
            self.multimodal_buckets = None
        else:
            if envvar == "":
-               if is_batch_based:
+               if 'InternVLChatModel' in str(type(model)):
+                   multimodal_buckets = list(
+                       range(model.config.min_dynamic_patch,
+                             model.config.max_dynamic_patch +
+                             2))  # use_thumbnail adds one extra patch
+               elif 'Gemma3ForConditionalGeneration' in str(type(model)):
                    multimodal_buckets = [1, 2, 4, 8]  # batch sizes for gemma3
                else:
+                   self.is_batch_based = False
                    multimodal_buckets = [
                        1600, 3136, 4096, 6400, 7744, 9216, 12544
                    ]
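
For context, here is a minimal sketch (not the plugin's actual code; both helper names are made up) of how a bucket list like the one above is typically consumed: the env var overrides the defaults, and an incoming image-token count (or image batch size, when `is_batch_based` is true) is rounded up to the smallest bucket that fits, so HPU graphs only need to be compiled once per bucket.

```python
import os

def parse_multimodal_buckets(default_buckets):
    # Hypothetical sketch: VLLM_MULTIMODAL_BUCKETS overrides the defaults,
    # e.g. "1600,3136,4096"; the literal string "None" disables bucketing.
    envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")
    if envvar == 'None':
        return None
    if envvar == "":
        return sorted(default_buckets)
    return sorted(int(x) for x in envvar.split(','))

def find_multimodal_bucket(buckets, n):
    # Round n (image tokens, or image batch size when is_batch_based) up
    # to the smallest bucket that fits; fall back to the largest bucket.
    for bucket in buckets:
        if n <= bucket:
            return bucket
    return buckets[-1]

buckets = parse_multimodal_buckets([1600, 3136, 4096, 6400, 7744, 9216, 12544])
assert find_multimodal_bucket(buckets, 2000) == 3136
```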
@@ -159,9 +165,11 @@ def __call__(cls, *args, **kwargs):
 
 
 def is_mm_optimized(model):
-    return 'Gemma3ForConditionalGeneration' in str(type(model.model)) \
-        if hasattr(model, 'model') else \
-        'Gemma3ForConditionalGeneration' in str(type(model))
+    mm_models = ['Gemma3ForConditionalGeneration', 'InternVLChatModel']
+
+    return any(m in str(type(model.model)) for m in mm_models) \
+        if hasattr(model, 'model') \
+        else any(m in str(type(model)) for m in mm_models)
 
 
 def pad_flat_tensor(tensor, desired_size):
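
The check relies on the model class name appearing in `str(type(...))`, so it works whether the model is wrapped in an object exposing `.model` or passed in directly. A toy illustration with stand-in classes:

```python
class InternVLChatModel:  # stand-in for the real vLLM model class
    pass

class Wrapper:  # stand-in for a runner/wrapper exposing .model
    def __init__(self, model):
        self.model = model

mm_models = ['Gemma3ForConditionalGeneration', 'InternVLChatModel']

def is_mm_optimized(model):
    return any(m in str(type(model.model)) for m in mm_models) \
        if hasattr(model, 'model') \
        else any(m in str(type(model)) for m in mm_models)

assert is_mm_optimized(InternVLChatModel())
assert is_mm_optimized(Wrapper(InternVLChatModel()))
```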
@@ -345,6 +353,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         model_config = getattr(self.model, "config", None)
 
         self.model_is_mrope = uses_mrope(model_config)
+
         self.is_mm_optimized = is_mm_optimized(self.model)
         text_config = vllm_config.model_config.hf_config.get_text_config()
         self.interleaved_sliding_window = getattr(
@@ -379,6 +388,12 @@ def __init__(self, model, vllm_config, is_causal, sampler):
                htorch.hpu.wrap_in_hpu_graph( \
                    self.model.multi_modal_projector, \
                    disable_tensor_cache=True)
+        if hasattr(self.model, 'vision_model'):
+            self.model.vision_model = htorch.hpu.wrap_in_hpu_graph(
+                self.model.vision_model, disable_tensor_cache=True)
+        if hasattr(self.model, 'mlp1'):
+            self.model.mlp1 = htorch.hpu.wrap_in_hpu_graph(
+                self.model.mlp1, disable_tensor_cache=True)
 
         self._rotary_embed_module = self._get_rotary_embedding_module(
             self.model)
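
The new branches extend the existing Gemma3 pattern (graph-wrapping `multi_modal_projector`) to InternVL's `vision_model` and `mlp1`. A hedged sketch of the general pattern, assuming the Habana PyTorch bridge (`habana_frameworks.torch`) is installed; the helper name is made up for illustration:

```python
import habana_frameworks.torch as htorch  # assumed available on HPU hosts

def wrap_optional_submodules(model, attr_names):
    # Wrap each submodule in an HPU graph only if the model actually has
    # it, so one code path serves both Gemma3 (multi_modal_projector) and
    # InternVL (vision_model, mlp1).
    for name in attr_names:
        if hasattr(model, name):
            setattr(
                model, name,
                htorch.hpu.wrap_in_hpu_graph(getattr(model, name),
                                             disable_tensor_cache=True))

# wrap_optional_submodules(self.model,
#                          ['multi_modal_projector', 'vision_model', 'mlp1'])
```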
@@ -624,26 +639,30 @@ def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
         inputs_embeds = self.model.get_input_embeddings(
             input_ids, vision_embeddings)
-
         # TODO: In warmup, we need to warm up the model with dummy image data
         # for multimodal models at the prompt stage. Instead of generating a
         # dummy image, we only generate the attn_mask for the images and pass
         # it with attn_metadata, so we can reuse the HPU graph without running
         # the whole vision tower.
         if vision_embeddings is not None or (
                 warmup_mode and kwargs['attn_metadata'].is_prompt):
-            input_ids = kwargs['input_ids']
-            positions = kwargs['positions']
-            kwargs = self.model.prepare_attn_masks(
-                mask_dtype=self.dtype,
-                **kwargs,
-            )
-            kwargs['input_ids'] = input_ids
-            kwargs['positions'] = positions
+            if hasattr(self.model, 'prepare_attn_masks'):
+                input_ids = kwargs['input_ids']
+                positions = kwargs['positions']
+                kwargs = self.model.prepare_attn_masks(
+                    mask_dtype=self.dtype,
+                    **kwargs,
+                )
+                kwargs['input_ids'] = input_ids
+                kwargs['positions'] = positions
+                # done computing the visual tokens
+                kwargs.pop('pixel_values', None)
+            else:
+                kwargs.pop('pixel_values_flat', None)
+                kwargs.pop("image_num_patches", None)
+                kwargs.pop("image_token_id", None)
 
         kwargs.update({'inputs_embeds': inputs_embeds})
-        # done compute the visual tokens and others
-        kwargs.pop('pixel_values', None)
         kwargs.pop("num_crops", None)
         kwargs.pop("graphed_multimodal_buckets", None)
         return kwargs
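
The branch above distinguishes Gemma3-style models (which expose `prepare_attn_masks`) from InternVL-style ones, and in both cases drops the raw multimodal tensors once their embeddings are folded into `inputs_embeds`. A hypothetical helper showing just that cleanup:

```python
def drop_consumed_mm_kwargs(model, kwargs):
    # Hypothetical helper mirroring the branch above: once the vision
    # embeddings are folded into inputs_embeds, the raw multimodal tensors
    # must not be forwarded to the language model.
    if hasattr(model, 'prepare_attn_masks'):  # Gemma3-style
        kwargs.pop('pixel_values', None)
    else:                                     # InternVL-style
        for key in ('pixel_values_flat', 'image_num_patches',
                    'image_token_id'):
            kwargs.pop(key, None)
    return kwargs
```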
@@ -699,7 +718,6 @@ def forward(self, *args, **kwargs):
         virtual_engine = 0
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
-
         input_ids = kwargs['input_ids']
         global_attn_masks = kwargs.pop("global_attn_masks") \
             if kwargs.get("global_attn_masks") else None
@@ -1080,6 +1098,8 @@ def __init__(
                              and not self.lora_config)
         self.use_delayed_sampling = get_config(
         ).use_delayed_sampling and can_use_delayed_sampling
+        self.mm_tokens_per_image = 1
+        self.image_token_id = 0
 
     def _set_gc_threshold(self) -> None:
         """
@@ -1497,10 +1517,16 @@ def move_to_device(self, tensor):
                                non_blocking=True)
 
     def add_vision_buckets_to_mrope_mm_optimized(self):
-        model = self.get_model()
-        self.is_mm_optimized = is_mm_optimized(model)
+        self.is_mm_optimized = is_mm_optimized(self.model)
         if self.model_is_mrope or self.is_mm_optimized:
-            model.vision_buckets = VisionBuckets(self.is_mm_optimized)
+            if hasattr(self.model.model.config, 'mm_tokens_per_image'):
+                self.mm_tokens_per_image = \
+                    self.model.model.config.mm_tokens_per_image
+                self.image_token_id = self.model.model.config.image_token_id
+            elif 'InternVLChatModel' in str(type(self.model.model)):
+                self.image_token_id = 151667  # hardcoded InternVL image token
+                self.mm_tokens_per_image = self.model.model.num_image_token
+            self.model.model.vision_buckets = VisionBuckets(self.model.model)
 
     def _prepare_prompt(
         self,
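
For reference, InternVL derives `num_image_token` from the vision-tower geometry. A hedged sketch of that arithmetic, assuming the common 448px image size, 14px patch size, and 0.5 pixel-shuffle downsample ratio:

```python
def internvl_num_image_token(image_size=448, patch_size=14,
                             downsample_ratio=0.5):
    # (448 // 14) ** 2 = 1024 raw patches; downsampling by 0.5 in each
    # spatial dimension leaves 1024 * 0.25 = 256 tokens per image patch.
    return int((image_size // patch_size) ** 2 * downsample_ratio ** 2)

assert internvl_num_image_token() == 256
```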
@@ -1631,7 +1657,6 @@ def _prepare_prompt(
                     for idx in range(3):
                         seq_data_mrope_positions[idx] \
                             .extend(mrope_positions[idx])
-
                 multi_modal_kwargs_list.append(mm_kwargs)
 
             for modality, placeholder_map in placeholder_maps.items():
@@ -2709,17 +2734,28 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
         else:
             s = self.model.model.config.vision_config.image_size
             pixel_values = torch.randn([img_args, 3, s, s])
-        num_image_tokens = self.model.model.config.mm_tokens_per_image \
-            * img_args
-        multi_modal_data = {
-            "pixel_values": pixel_values,
-            "num_crops": torch.zeros([img_args], dtype=torch.int32)
-        }
 
-        image_token_id = self.get_model().config.image_token_id
-        prompt_token_ids_image = [image_token_id] * num_image_tokens
+        if 'Gemma3ForConditionalGeneration' in str(type(self.model.model)):
+            multi_modal_data = {
+                "pixel_values": pixel_values,
+                "num_crops": torch.zeros([img_args], dtype=torch.int32),
+            }
+        elif 'InternVLChatModel' in str(type(self.model.model)):
+            multi_modal_data = {
+                "pixel_values_flat": pixel_values.to(torch.bfloat16),
+                "image_num_patches": torch.tensor([pixel_values.shape[0]],
+                                                  dtype=torch.int32),
+                "image_token_id": torch.tensor([self.image_token_id],
+                                               dtype=torch.int64),
+            }
+        else:
+            logger.warning("No support for other models yet")
+        num_image_tokens = self.mm_tokens_per_image * img_args
+        prompt_token_ids_image = [self.image_token_id] * num_image_tokens
         prompt_token_ids = [0] * (
             seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
+
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         placeholders_by_modality = {
             'image':
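
A worked example of the token accounting used for the dummy sequence (numbers are assumed, except the token id, which matches the hardcoded InternVL value above): each of the `img_args` dummy patches contributes `mm_tokens_per_image` image tokens, and the prompt is left-padded with zeros up to `seq_len`.

```python
# Worked example; img_args, mm_tokens_per_image and seq_len are assumed.
img_args = 8               # dummy patches in pixel_values_flat
mm_tokens_per_image = 256  # image tokens contributed per patch
seq_len = 4096
num_image_tokens = mm_tokens_per_image * img_args  # 2048
prompt_token_ids = [0] * (seq_len - num_image_tokens) \
    + [151667] * num_image_tokens  # 151667: image token id from above
assert len(prompt_token_ids) == seq_len
```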
@@ -3188,9 +3224,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                 if graphs:
                     self.graphed_buckets.add(cfg)
                 if self.is_mm_run():
-                    img_args = (int(seq_len) //
-                                self.model.model.config.mm_tokens_per_image
-                                if self.is_mm_optimized else int(seq_len))
+                    img_args = int(seq_len) // self.mm_tokens_per_image
                     self.warmup_scenario(
                         int(bs),
                         int(seq_len),
@@ -3539,7 +3573,7 @@ def _get_seq_ids(self, model_input):
     def _get_img_args_from_model_input(self, model_input):
         if (not self.model_is_mrope and not self.is_mm_optimized) or \
                 not model_input.multi_modal_kwargs or \
-                'pixel_values' not in model_input.multi_modal_kwargs:
+                ('pixel_values') not in model_input.multi_modal_kwargs:
             return None
         if self.model_is_mrope:
             pixel_values_list = model_input.multi_modal_kwargs['pixel_values']
@@ -3816,18 +3850,17 @@ def try_revert_dummy_output_tokens():
                 'real_seq_len': model_input.seq_lens,
                 'real_batch_size': real_batch_size
             }
-
             # Need to set the sliding-window mask here to decide on window SDPA
             if is_prompt:
                 attn_metadata = self.model._update_use_window_sdpa(
                     execute_model_kwargs['attn_metadata'], seq_len,
                     bool(model_input.multi_modal_kwargs and \
-                        'pixel_values' in model_input.multi_modal_kwargs))
+                        ('pixel_values') in model_input.multi_modal_kwargs))
                 execute_model_kwargs['attn_metadata'] = attn_metadata
 
             if not bypass_model_exec:
                 if self.model_is_mrope or self.is_mm_optimized:
-                    if 'pixel_values' in execute_model_kwargs and \
+                    if ('pixel_values') in execute_model_kwargs and \
                             self.is_mm_optimized:
                         if warmup_mode and not is_pt_profiler_run:
                             bypass_model_exec = True