@@ -373,7 +373,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         if self.is_mm_optimized:
             if hasattr(self.model, 'vision_tower'):
                 self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
-                    self.model.vision_tower, disable_tensor_cache=True)
+                    self.model.vision_tower, disable_tensor_cache=False)
             if hasattr(self.model, 'multi_modal_projector'):
                 self.model.multi_modal_projector = \
                     htorch.hpu.wrap_in_hpu_graph( \
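For context, a minimal usage sketch of the wrapping pattern touched above. It assumes a Gaudi host with habana_frameworks installed; the toy module and the comment on disable_tensor_cache reflect a reading of the API, not anything stated in this PR:

import torch
import habana_frameworks.torch as htorch

# Toy stand-in for a vision tower; any torch.nn.Module can be wrapped.
vision_tower = torch.nn.Linear(16, 16).to('hpu')

# With disable_tensor_cache=False (the value this PR switches to), the
# captured graph keeps its cached I/O tensors alive between replays rather
# than freeing them after each run.
vision_tower = htorch.hpu.wrap_in_hpu_graph(vision_tower,
                                            disable_tensor_cache=False)
out = vision_tower(torch.randn(4, 16, device='hpu'))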
@@ -619,13 +619,19 @@ def _update_metadata(self,
                                              device, dtype, True)
         return attn_metadata

-    def compute_input_embeddings_for_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         input_ids = kwargs['input_ids']
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
         inputs_embeds = self.model.get_input_embeddings(
             input_ids, vision_embeddings)

-        if vision_embeddings is not None:
+        # TODO: During warmup, multimodal prompts would normally require
+        # dummy image data. Instead of generating a dummy image, we only
+        # build the attention masks for the image tokens and pass them along
+        # via attn_metadata, so the HPU graph can be reused without running
+        # the whole vision tower.
+        if vision_embeddings is not None or (
+                warmup_mode and kwargs['attn_metadata'].is_prompt):
             input_ids = kwargs['input_ids']
             positions = kwargs['positions']
             kwargs = self.model.prepare_attn_masks(
@@ -634,14 +640,16 @@ def compute_input_embeddings_for_mm_optimized(self, **kwargs):
             )
             kwargs['input_ids'] = input_ids
             kwargs['positions'] = positions
-            #input_ids = None

         kwargs.update({'inputs_embeds': inputs_embeds})
-        # done compute the visual tokens
+        # done computing the visual tokens and other inputs
         kwargs.pop('pixel_values', None)
+        kwargs.pop("num_crops", None)
+        kwargs.pop("graphed_multimodal_buckets", None)
         return kwargs

-    def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mrope_mm_optimized(
+            self, warmup_mode, **kwargs):

         if 'inputs_embeds' in kwargs:
             return kwargs
@@ -680,7 +688,8 @@ def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
             kwargs.pop('image_grid_thw', None)
             return kwargs
         else:
-            return self.compute_input_embeddings_for_mm_optimized(**kwargs)
+            return self.compute_input_embeddings_for_mm_optimized(
+                warmup_mode, **kwargs)

     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
@@ -692,9 +701,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = kwargs.pop('virtual_engine')

         input_ids = kwargs['input_ids']
-        global_attn_masks = kwargs.get("global_attn_masks") \
+        global_attn_masks = kwargs.pop("global_attn_masks") \
             if kwargs.get("global_attn_masks") else None
-        local_attn_masks = kwargs.get("local_attn_masks") \
+        local_attn_masks = kwargs.pop("local_attn_masks") \
             if kwargs.get("local_attn_masks") else None

         kwargs['attn_metadata'] = self._update_metadata(
@@ -1396,12 +1405,8 @@ def get_model(self) -> torch.nn.Module:
             return self.model.model
         return self.model

-    def _use_graphs(self, img_args=None):
-        if not img_args:
-            return not self.enforce_eager
-        #TODO: We might need to check both language bucket and multimodal bucket
-        # and return True only it's avialble, or return separately.
-        return (img_args) in self.graphed_multimodal_buckets
+    def _use_graphs(self):
+        return not self.enforce_eager

     def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens
@@ -2667,7 +2672,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:

     def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
                                                     sampling_params,
-                                                    lora_request):
+                                                    lora_request, seq_len):
         assert self.model_is_mrope or self.is_mm_optimized, \
             ("Warmup compatible with Qwen2vl/Gemma3 models")
         if img_args == UNSET_IMG_ARGS:
@@ -2712,7 +2717,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
         }

         image_token_id = self.get_model().config.image_token_id
-        prompt_token_ids = [image_token_id] * num_image_tokens
+        prompt_token_ids_image = [image_token_id] * num_image_tokens
+        prompt_token_ids = [0] * (
+            seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         placeholders_by_modality = {
             'image':
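As a standalone illustration of the new dummy-prompt layout above: the prompt is left-padded with token id 0 so it fills the warmed-up sequence length, keeping the image placeholder tokens contiguous at the end. The concrete numbers below are placeholders, not values from this PR:

image_token_id = 99      # hypothetical image placeholder id
num_image_tokens = 4     # assumed tokens contributed by one image
seq_len = 16             # bucket sequence length being warmed up

prompt_token_ids_image = [image_token_id] * num_image_tokens
prompt_token_ids = [0] * (
    seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image

assert len(prompt_token_ids) == seq_len
assert prompt_token_ids[-num_image_tokens:] == prompt_token_ids_image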
@@ -2756,6 +2763,7 @@ def create_dummy_seq_group_metadata(self,
                 img_args=img_args,
                 sampling_params=sampling_params,
                 lora_request=lora_request,
+                seq_len=seq_len,
             )
         else:
             input_len = seq_len
@@ -2867,7 +2875,7 @@ def warmup_scenario(self,
                         align_worker=False,
                         is_dummy_run=False) -> None:
         phase = 'prompt' if is_prompt else 'decode'
-        use_graphs = is_dummy_run or self._use_graphs(img_args)
+        use_graphs = is_dummy_run or self._use_graphs()

         scenario_name = ("warmup_"
                         f"{phase}_"
@@ -3664,8 +3672,7 @@ def execute_model(
             if not warmup_mode:
                 ctx_blocks = seq_len
             seq_len = 1
-        img_args = self._get_img_args_from_model_input(model_input)
-        use_graphs = self._use_graphs(img_args=img_args)
+        use_graphs = self._use_graphs()
         self._check_config(batch_size, seq_len, ctx_blocks, attn_metadata,
                            warmup_mode)
         lora_mask: torch.Tensor = None
@@ -3831,6 +3838,7 @@ def try_revert_dummy_output_tokens():
                 # hpu graphs, hence turning it to a list
                 execute_model_kwargs = \
                     self.model.compute_input_embeddings_for_mrope_mm_optimized(
+                        warmup_mode,
                         **execute_model_kwargs
                     )
                 if warmup_mode and bypass_model_exec:
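Finally, a toy sketch of the idea behind the warmup branch added in compute_input_embeddings_for_mm_optimized: the image attention masks can be derived from placeholder token positions alone, so a prompt bucket can be warmed up and its HPU graph captured without running the vision tower. This is not the vLLM prepare_attn_masks implementation, just an assumed Gemma3-style mask (causal, with bidirectional attention inside an image span):

import torch

IMAGE_TOKEN_ID = 99  # hypothetical placeholder id

def dummy_image_attn_mask(input_ids: torch.Tensor) -> torch.Tensor:
    # Causal mask over the sequence, relaxed so that image placeholder
    # tokens attend to each other bidirectionally.
    seq_len = input_ids.shape[-1]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    is_image = input_ids == IMAGE_TOKEN_ID
    bidirectional = is_image.unsqueeze(0) & is_image.unsqueeze(1)
    return causal | bidirectional

# Dummy prompt: text padding followed by a contiguous image span.
ids = torch.tensor([0] * 6 + [IMAGE_TOKEN_ID] * 4)
mask = dummy_image_attn_mask(ids)
assert mask.shape == (10, 10)
assert mask[6, 9]  # image tokens also see later image tokens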