@@ -374,7 +374,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         if self.is_mm_optimized:
             if hasattr(self.model, 'vision_tower'):
                 self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
-                    self.model.vision_tower, disable_tensor_cache=True)
+                    self.model.vision_tower, disable_tensor_cache=False)
             if hasattr(self.model, 'multi_modal_projector'):
                 self.model.multi_modal_projector = \
                     htorch.hpu.wrap_in_hpu_graph( \
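For context, `htorch.hpu.wrap_in_hpu_graph` captures a module's forward pass as a replayable HPU graph, and `disable_tensor_cache=False` keeps the wrapper's cached tensors alive between replays, which matters once the same captured graph is replayed during warmup. A minimal sketch of the pattern, assuming a Gaudi host with `habana_frameworks` installed; `TinyTower` is a stand-in module, not the model's real vision tower:

```python
import torch
import habana_frameworks.torch as htorch

class TinyTower(torch.nn.Module):
    """Stand-in for a vision tower; any nn.Module works."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

tower = TinyTower().to('hpu')
# disable_tensor_cache=False keeps the graph's cached tensors alive,
# so the captured graph can be reused across calls.
tower = htorch.hpu.wrap_in_hpu_graph(tower, disable_tensor_cache=False)
out = tower(torch.randn(4, 16, device='hpu'))  # first call captures the graph
```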
@@ -620,13 +620,19 @@ def _update_metadata(self,
                                              device, dtype, True)
         return attn_metadata

-    def compute_input_embeddings_for_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         input_ids = kwargs['input_ids']
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
         inputs_embeds = self.model.get_input_embeddings(
             input_ids, vision_embeddings)

-        if vision_embeddings is not None:
+        # TODO: During warmup we need to warm up the model with dummy image
+        # data for multimodal prompts. Instead of generating a dummy image,
+        # we generate only the attention masks for the images and pass them
+        # via attn_metadata, so the HPU graph can be reused without running
+        # the whole vision tower.
+        if vision_embeddings is not None or (
+                warmup_mode and kwargs['attn_metadata'].is_prompt):
             input_ids = kwargs['input_ids']
             positions = kwargs['positions']
             kwargs = self.model.prepare_attn_masks(
@@ -635,14 +641,16 @@ def compute_input_embeddings_for_mm_optimized(self, **kwargs):
             )
             kwargs['input_ids'] = input_ids
             kwargs['positions'] = positions
-            #input_ids = None

         kwargs.update({'inputs_embeds': inputs_embeds})
-        # done compute the visual tokens
+        # done computing the visual tokens and the other multimodal inputs
         kwargs.pop('pixel_values', None)
+        kwargs.pop("num_crops", None)
+        kwargs.pop("graphed_multimodal_buckets", None)
         return kwargs

-    def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mrope_mm_optimized(
+            self, warmup_mode, **kwargs):

         if 'inputs_embeds' in kwargs:
             return kwargs
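Taken together, the two hunks above let the same helper serve both real multimodal prompts and warmup runs. A sketch of the resulting control flow, with hypothetical stand-ins for the model helpers and the `attn_metadata` object (not the PR's code):

```python
import torch

def compute_embeddings(model, warmup_mode, **kwargs):
    # None during warmup, since no real image is provided.
    vision_embeddings = model.get_multimodal_embeddings(**kwargs)
    inputs_embeds = model.get_input_embeddings(
        kwargs['input_ids'], vision_embeddings)
    # Real images OR a prompt warmup both take the mask-building path,
    # so the captured HPU graph sees an identical set of inputs either way.
    if vision_embeddings is not None or (
            warmup_mode and kwargs['attn_metadata'].is_prompt):
        kwargs = model.prepare_attn_masks(mask_dtype=torch.bfloat16, **kwargs)
    kwargs['inputs_embeds'] = inputs_embeds
    return kwargs
```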
@@ -681,7 +689,8 @@ def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
             kwargs.pop('image_grid_thw', None)
             return kwargs
         else:
-            return self.compute_input_embeddings_for_mm_optimized(**kwargs)
+            return self.compute_input_embeddings_for_mm_optimized(
+                warmup_mode, **kwargs)

     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
@@ -693,9 +702,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = kwargs.pop('virtual_engine')

         input_ids = kwargs['input_ids']
-        global_attn_masks = kwargs.get("global_attn_masks") \
+        global_attn_masks = kwargs.pop("global_attn_masks") \
             if kwargs.get("global_attn_masks") else None
-        local_attn_masks = kwargs.get("local_attn_masks") \
+        local_attn_masks = kwargs.pop("local_attn_masks") \
             if kwargs.get("local_attn_masks") else None

         kwargs['attn_metadata'] = self._update_metadata(
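The switch from `get` to `pop` matters because the masks are consumed in `forward` itself; popping them presumably keeps stale entries out of the kwargs that flow on to the wrapped model. A trivial illustration of the difference (values hypothetical):

```python
kwargs = {'input_ids': [1, 2, 3], 'global_attn_masks': ['mask']}
# pop() both reads and removes the entry, unlike get():
global_attn_masks = kwargs.pop('global_attn_masks') \
    if kwargs.get('global_attn_masks') else None
assert 'global_attn_masks' not in kwargs  # no longer forwarded downstream
```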
@@ -1397,12 +1406,8 @@ def get_model(self) -> torch.nn.Module:
             return self.model.model
         return self.model

-    def _use_graphs(self, img_args=None):
-        if not img_args:
-            return not self.enforce_eager
-        #TODO: We might need to check both language bucket and multimodal bucket
-        # and return True only it's avialble, or return separately.
-        return (img_args) in self.graphed_multimodal_buckets
+    def _use_graphs(self):
+        return not self.enforce_eager

     def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens
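With the multimodal-bucket check removed, graph use now depends only on the eager flag, which is also why the `img_args` plumbing disappears from the call sites further down. A behaviour sketch with a hypothetical runner object:

```python
class Runner:
    """Hypothetical stand-in for the model runner."""
    enforce_eager = False

    def _use_graphs(self):
        return not self.enforce_eager

# Graphs are used whenever eager mode isn't forced, for any bucket.
assert Runner()._use_graphs() is True
```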
@@ -2668,7 +2673,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:

     def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
                                                     sampling_params,
-                                                    lora_request):
+                                                    lora_request, seq_len):
         assert self.model_is_mrope or self.is_mm_optimized, \
             ("Warmup compatible with Qwen2vl/Gemma3 models")
         if img_args == UNSET_IMG_ARGS:
@@ -2713,7 +2718,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
27132718 }
27142719
27152720 image_token_id = self .get_model ().config .image_token_id
2716- prompt_token_ids = [image_token_id ] * num_image_tokens
2721+ prompt_token_ids_image = [image_token_id ] * num_image_tokens
2722+ prompt_token_ids = [0 ] * (
2723+ seq_len - len (prompt_token_ids_image )) + prompt_token_ids_image
27172724 prompt_token_ids_array = array ('l' , prompt_token_ids ) # noqa: F821
27182725 placeholders_by_modality = {
27192726 'image' :
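A worked example of the new dummy-prompt layout: the image placeholder tokens now sit at the end of a `seq_len`-long sequence, front-padded with token id 0, so the dummy prompt fills the whole warmup bucket (values below are illustrative, not the real token id):

```python
image_token_id = 99   # illustrative id, not the model's real one
num_image_tokens = 3
seq_len = 8

prompt_token_ids_image = [image_token_id] * num_image_tokens
prompt_token_ids = [0] * (
    seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image

print(prompt_token_ids)                  # [0, 0, 0, 0, 0, 99, 99, 99]
assert len(prompt_token_ids) == seq_len  # dummy prompt matches the bucket
```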
@@ -2757,6 +2764,7 @@ def create_dummy_seq_group_metadata(self,
                 img_args=img_args,
                 sampling_params=sampling_params,
                 lora_request=lora_request,
+                seq_len=seq_len,
            )
        else:
            input_len = seq_len
@@ -2868,7 +2876,7 @@ def warmup_scenario(self,
                         align_worker=False,
                         is_dummy_run=False) -> None:
         phase = 'prompt' if is_prompt else 'decode'
-        use_graphs = is_dummy_run or self._use_graphs(img_args)
+        use_graphs = is_dummy_run or self._use_graphs()

         scenario_name = ("warmup_"
                         f"{phase}_"
@@ -3665,8 +3673,7 @@ def execute_model(
             if not warmup_mode:
                 ctx_blocks = seq_len
             seq_len = 1
-            img_args = self._get_img_args_from_model_input(model_input)
-            use_graphs = self._use_graphs(img_args=img_args)
+            use_graphs = self._use_graphs()
         self._check_config(batch_size, seq_len, ctx_blocks, attn_metadata,
                            warmup_mode)
         lora_mask: torch.Tensor = None
@@ -3832,6 +3839,7 @@ def try_revert_dummy_output_tokens():
                 # hpu graphs, hence turning it to a list
                 execute_model_kwargs = \
                     self.model.compute_input_embeddings_for_mrope_mm_optimized(
+                        warmup_mode,
                         **execute_model_kwargs
                     )
                 if warmup_mode and bypass_model_exec: