Commit 36eb2cf
Fixes HPU graph run for Gemma3 vision inputs (#1865)
Fixes HPU graph issues for Gemma3 vision inputs:

- Text warmup now includes attn_mask info, so vision+text requests can reuse the language-model graph that has already been warmed up.
- Change slicing to index_select for multimodal bucketing on HPU; slicing does not produce the same HPU graph hash even for the same input shape.
- Use buckets for the vision tower as well to reduce GC recompiles.
- Fix an accuracy bug by cloning the output of the multi_modal_projector.

Validated with the MuirBench dataset.
1 parent f7d88c3 commit 36eb2cf
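The change from slicing to index_select is worth a quick illustration: over a contiguous index range, index_select returns the same values and shape as the slice, so the swap only changes how the bucket-sized chunk is materialized for the HPU graph, not what the vision tower computes. A minimal sketch in plain PyTorch, with illustrative tensor names and shapes not taken from the repo:

import torch

# index_select over a contiguous range is value- and shape-identical to slicing;
# the commit prefers it because, per the message above, a sliced view does not
# hash to the same compiled HPU graph as an explicit gather of the same shape.
pixel_values = torch.randn(7, 3, 224, 224)   # hypothetical batch of image patches
start_idx, end_idx = 0, 4                    # one bucket-sized chunk

sliced = pixel_values[start_idx:end_idx, ...]
indices = torch.arange(start_idx, end_idx)
selected = torch.index_select(pixel_values, dim=0, index=indices)

assert sliced.shape == selected.shape
assert torch.equal(sliced, selected)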

File tree: 3 files changed, +44 -39 lines changed

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 model_name: "/mnt/weka/data/pytorch/Qwen/Qwen2.5-VL-7B-Instruct/"
 dtype: "bfloat16"
-max_model_len: 32768
+max_model_len: 35840
 max_num_seqs: 32
 num_prompts: 4

vllm/model_executor/models/gemma3_mm.py

Lines changed: 15 additions & 18 deletions

@@ -569,11 +569,6 @@ def _process_image_input(
         pixel_values = image_input["pixel_values"]
         num_patches = image_input["num_patches"]

-        image_features = self._image_pixels_to_features(
-            self.vision_tower,
-            pixel_values,
-        )
-
         if is_hpu:
             batch_breakdown = greedy_plan(pixel_values.shape[0], \
                                           self.vision_buckets.multimodal_buckets)
@@ -582,22 +577,24 @@ def _process_image_input(

             for i in batch_breakdown:
                 end_idx = start_idx + i
-                batch_sliced_image_features = \
-                    image_features[start_idx:end_idx, ...]
-                if is_lazy:
-                    image_embeds_multibatches += \
-                        [self.multi_modal_projector(
-                            batch_sliced_image_features,
-                            bypass_hpu_graphs=i
-                            not in self.graphed_multimodal_buckets
-                            and len(self.graphed_multimodal_buckets) > 0)]
-                else:
-                    image_embeds_multibatches += \
-                        [self.multi_modal_projector( \
-                            batch_sliced_image_features)]
+                indices = torch.arange(start_idx, end_idx)
+                batch_sliced_pixel_values = torch.index_select(pixel_values,
+                                                               dim=0,
+                                                               index=indices)
+
+                image_features = self._image_pixels_to_features(
+                    self.vision_tower,
+                    batch_sliced_pixel_values,
+                )
+                image_embeds = self.multi_modal_projector(image_features)
+                image_embeds_multibatches += [image_embeds.clone()]
                 start_idx = end_idx
             image_embeds = torch.cat(image_embeds_multibatches, dim=0)
         else:
+            image_features = self._image_pixels_to_features(
+                self.vision_tower,
+                pixel_values,
+            )
             image_embeds = self.multi_modal_projector(image_features)
         return [
             e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
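The .clone() on the projector output above is the accuracy fix called out in the commit message. A minimal, HPU-free sketch of the general aliasing hazard it guards against; the module and buffer names below are illustrative only, not the behavior of vLLM or the HPU graph wrapper:

import torch

# If a wrapped module reuses its output buffer across calls (as graph-cached
# execution may), appending the raw output to a list leaves every entry
# pointing at the same storage, so later calls overwrite earlier results.
class BufferReusingProjector(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.register_buffer("out", torch.empty(dim))

    def forward(self, x):
        self.out.copy_(x * 2)   # writes into the same buffer on every call
        return self.out

proj = BufferReusingProjector(4)
without_clone = [proj(torch.full((4,), float(i))) for i in range(3)]
with_clone = [proj(torch.full((4,), float(i))).clone() for i in range(3)]

print(without_clone[0])  # tensor([4., 4., 4., 4.]) -- clobbered by the last call
print(with_clone[0])     # tensor([0., 0., 0., 0.]) -- preserved by .clone()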

vllm/worker/hpu_model_runner.py

Lines changed: 28 additions & 20 deletions

@@ -373,7 +373,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         if self.is_mm_optimized:
             if hasattr(self.model, 'vision_tower'):
                 self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
-                    self.model.vision_tower, disable_tensor_cache=True)
+                    self.model.vision_tower, disable_tensor_cache=False)
             if hasattr(self.model, 'multi_modal_projector'):
                 self.model.multi_modal_projector = \
                     htorch.hpu.wrap_in_hpu_graph( \
@@ -619,13 +619,19 @@ def _update_metadata(self,
                                        device, dtype, True)
         return attn_metadata

-    def compute_input_embeddings_for_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         input_ids = kwargs['input_ids']
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
         inputs_embeds = self.model.get_input_embeddings(
             input_ids, vision_embeddings)

-        if vision_embeddings is not None:
+        # TODO: In warmup, we need to warmup the model with dummy image data for
+        # multimodal model for prompt, here instead of generating a dummy image,
+        # we are just generating attn_mask for the images and pass with
+        # attn_metadata, so we can reuse HPU graph without running
+        # the whole vision tower.
+        if vision_embeddings is not None or (
+                warmup_mode and kwargs['attn_metadata'].is_prompt):
             input_ids = kwargs['input_ids']
             positions = kwargs['positions']
             kwargs = self.model.prepare_attn_masks(
@@ -634,14 +640,16 @@ def compute_input_embeddings_for_mm_optimized(self, **kwargs):
             )
             kwargs['input_ids'] = input_ids
             kwargs['positions'] = positions
-            #input_ids = None

         kwargs.update({'inputs_embeds': inputs_embeds})
-        # done compute the visual tokens
+        # done compute the visual tokens and others
         kwargs.pop('pixel_values', None)
+        kwargs.pop("num_crops", None)
+        kwargs.pop("graphed_multimodal_buckets", None)
         return kwargs

-    def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mrope_mm_optimized(
+            self, warmup_mode, **kwargs):

         if 'inputs_embeds' in kwargs:
             return kwargs
@@ -680,7 +688,8 @@ def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
             kwargs.pop('image_grid_thw', None)
             return kwargs
         else:
-            return self.compute_input_embeddings_for_mm_optimized(**kwargs)
+            return self.compute_input_embeddings_for_mm_optimized(
+                warmup_mode, **kwargs)

     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
@@ -692,9 +701,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = kwargs.pop('virtual_engine')

         input_ids = kwargs['input_ids']
-        global_attn_masks = kwargs.get("global_attn_masks") \
+        global_attn_masks = kwargs.pop("global_attn_masks") \
             if kwargs.get("global_attn_masks") else None
-        local_attn_masks = kwargs.get("local_attn_masks") \
+        local_attn_masks = kwargs.pop("local_attn_masks") \
             if kwargs.get("local_attn_masks") else None

         kwargs['attn_metadata'] = self._update_metadata(
@@ -1396,12 +1405,8 @@ def get_model(self) -> torch.nn.Module:
             return self.model.model
         return self.model

-    def _use_graphs(self, img_args=None):
-        if not img_args:
-            return not self.enforce_eager
-        #TODO: We might need to check both language bucket and multimodal bucket
-        # and return True only it's avialble, or return separately.
-        return (img_args) in self.graphed_multimodal_buckets
+    def _use_graphs(self):
+        return not self.enforce_eager

     def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens
@@ -2667,7 +2672,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:

     def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
                                                     sampling_params,
-                                                    lora_request):
+                                                    lora_request, seq_len):
         assert self.model_is_mrope or self.is_mm_optimized, \
             ("Warmup compatible with Qwen2vl/Gemma3 models")
         if img_args == UNSET_IMG_ARGS:
@@ -2712,7 +2717,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
         }

         image_token_id = self.get_model().config.image_token_id
-        prompt_token_ids = [image_token_id] * num_image_tokens
+        prompt_token_ids_image = [image_token_id] * num_image_tokens
+        prompt_token_ids = [0] * (
+            seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         placeholders_by_modality = {
             'image':
@@ -2756,6 +2763,7 @@ def create_dummy_seq_group_metadata(self,
                 img_args=img_args,
                 sampling_params=sampling_params,
                 lora_request=lora_request,
+                seq_len=seq_len,
             )
         else:
             input_len = seq_len
@@ -2867,7 +2875,7 @@ def warmup_scenario(self,
                         align_worker=False,
                         is_dummy_run=False) -> None:
         phase = 'prompt' if is_prompt else 'decode'
-        use_graphs = is_dummy_run or self._use_graphs(img_args)
+        use_graphs = is_dummy_run or self._use_graphs()

         scenario_name = ("warmup_"
                         f"{phase}_"
@@ -3664,8 +3672,7 @@ def execute_model(
             if not warmup_mode:
                 ctx_blocks = seq_len
             seq_len = 1
-            img_args = self._get_img_args_from_model_input(model_input)
-            use_graphs = self._use_graphs(img_args=img_args)
+            use_graphs = self._use_graphs()
         self._check_config(batch_size, seq_len, ctx_blocks, attn_metadata,
                            warmup_mode)
         lora_mask: torch.Tensor = None
@@ -3831,6 +3838,7 @@ def try_revert_dummy_output_tokens():
                 # hpu graphs, hence turning it to a list
                 execute_model_kwargs = \
                     self.model.compute_input_embeddings_for_mrope_mm_optimized(
+                        warmup_mode,
                         **execute_model_kwargs
                     )
                 if warmup_mode and bypass_model_exec:
