
Commit 86c6f70

Fixes HPU graph run for Gemma3 vision inputs (#1865)
Fixes HPU graph issues for Gemma3 vision inputs:
- Text warmup now includes attn_mask info, so vision+text inputs can reuse the language-model graph that has already been warmed up.
- Slicing is replaced with index_select for multimodal bucketing on HPU, since slicing does not produce the same HPU graph hash for the same input shape.
- The vision tower also uses buckets, reducing graph recompilation.
- An accuracy bug is fixed by cloning the output of the multi-modal projector.

Validated with the MuirBench dataset.
1 parent 6a1d7ad commit 86c6f70
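The bucketed flow described in the commit message can be summarized in a small, self-contained sketch. This is NOT the actual vLLM/Gemma3 code: plan_buckets, the stand-in vision tower/projector, and all shapes are made up. Each bucket of images is gathered with torch.index_select (a fresh tensor rather than a view), run through the vision tower and projector, and the projector output is cloned before being collected, mirroring the accuracy fix in this commit.

import torch
import torch.nn as nn

def plan_buckets(total, buckets):
    # Greedily cover `total` rows with the largest bucket sizes first.
    # (A real implementation pads up to a bucket; the clamping below is a
    # simplification for this sketch.)
    plan, remaining = [], total
    for b in sorted(buckets, reverse=True):
        while remaining >= b:
            plan.append(b)
            remaining -= b
    if remaining:
        plan.append(min(buckets))
    return plan

vision_tower = nn.Linear(16, 32)    # stand-in for the vision tower
projector = nn.Linear(32, 8)        # stand-in for the multi-modal projector
pixel_values = torch.randn(10, 16)  # ten "images" worth of flattened patches

image_embeds_multibatches, start_idx = [], 0
for bucket in plan_buckets(pixel_values.shape[0], [1, 2, 4, 8]):
    end_idx = min(start_idx + bucket, pixel_values.shape[0])
    indices = torch.arange(start_idx, end_idx)
    batch = torch.index_select(pixel_values, dim=0, index=indices)
    image_features = vision_tower(batch)
    image_embeds = projector(image_features)
    # clone so a reused graph output buffer cannot be overwritten later
    image_embeds_multibatches.append(image_embeds.clone())
    start_idx = end_idx

image_embeds = torch.cat(image_embeds_multibatches, dim=0)
print(image_embeds.shape)  # torch.Size([10, 8])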

File tree

3 files changed (+44, -39 lines changed)
Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 model_name: "/mnt/weka/data/pytorch/Qwen/Qwen2.5-VL-7B-Instruct/"
 dtype: "bfloat16"
-max_model_len: 32768
+max_model_len: 35840
 max_num_seqs: 32
 num_prompts: 4

vllm/model_executor/models/gemma3_mm.py

Lines changed: 15 additions & 18 deletions
@@ -569,11 +569,6 @@ def _process_image_input(
         pixel_values = image_input["pixel_values"]
         num_patches = image_input["num_patches"]
 
-        image_features = self._image_pixels_to_features(
-            self.vision_tower,
-            pixel_values,
-        )
-
         if is_hpu:
             batch_breakdown = greedy_plan(pixel_values.shape[0], \
                         self.vision_buckets.multimodal_buckets)
@@ -582,22 +577,24 @@ def _process_image_input(
 
             for i in batch_breakdown:
                 end_idx = start_idx + i
-                batch_sliced_image_features = \
-                    image_features[start_idx:end_idx, ...]
-                if is_lazy:
-                    image_embeds_multibatches += \
-                        [self.multi_modal_projector(
-                            batch_sliced_image_features,
-                            bypass_hpu_graphs=i
-                            not in self.graphed_multimodal_buckets
-                            and len(self.graphed_multimodal_buckets) > 0)]
-                else:
-                    image_embeds_multibatches += \
-                        [self.multi_modal_projector( \
-                            batch_sliced_image_features)]
+                indices = torch.arange(start_idx, end_idx)
+                batch_sliced_pixel_values = torch.index_select(pixel_values,
+                                                               dim=0,
+                                                               index=indices)
+
+                image_features = self._image_pixels_to_features(
+                    self.vision_tower,
+                    batch_sliced_pixel_values,
+                )
+                image_embeds = self.multi_modal_projector(image_features)
+                image_embeds_multibatches += [image_embeds.clone()]
                 start_idx = end_idx
             image_embeds = torch.cat(image_embeds_multibatches, dim=0)
         else:
+            image_features = self._image_pixels_to_features(
+                self.vision_tower,
+                pixel_values,
+            )
             image_embeds = self.multi_modal_projector(image_features)
         return [
             e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
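For reference, a minimal CPU-only illustration (not Habana-specific) of why the old slice and the new index_select are not interchangeable inputs even for identical shapes: the slice is a view with a storage offset into the parent tensor, while index_select materializes a new contiguous tensor. How the HPU graph hash is actually computed is not shown here; this only demonstrates that the two inputs differ at the tensor/storage level.

import torch

x = torch.randn(8, 4)
view = x[2:6]                                        # view into x's storage
copy = torch.index_select(x, 0, torch.arange(2, 6))  # new allocation

print(view.shape == copy.shape)                      # True  -> same shape
print(view.storage_offset(), copy.storage_offset())  # 8 0
print(view.untyped_storage().data_ptr() == x.untyped_storage().data_ptr())  # True  -> shares x's storage
print(copy.untyped_storage().data_ptr() == x.untyped_storage().data_ptr())  # False -> separate allocation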

vllm/worker/hpu_model_runner.py

Lines changed: 28 additions & 20 deletions
@@ -374,7 +374,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         if self.is_mm_optimized:
             if hasattr(self.model, 'vision_tower'):
                 self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
-                    self.model.vision_tower, disable_tensor_cache=True)
+                    self.model.vision_tower, disable_tensor_cache=False)
             if hasattr(self.model, 'multi_modal_projector'):
                 self.model.multi_modal_projector = \
                     htorch.hpu.wrap_in_hpu_graph( \
@@ -620,13 +620,19 @@ def _update_metadata(self,
                                                  device, dtype, True)
         return attn_metadata
 
-    def compute_input_embeddings_for_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         input_ids = kwargs['input_ids']
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
         inputs_embeds = self.model.get_input_embeddings(
             input_ids, vision_embeddings)
 
-        if vision_embeddings is not None:
+        # TODO: In warmup, we need to warmup the model with dummy image data for
+        # multimodal model for prompt, here instead of generating a dummy image,
+        # we are just generating attn_mask for the images and pass with
+        # attn_metadata, so we can reuse HPU graph without running
+        # the whole vision tower.
+        if vision_embeddings is not None or (
+                warmup_mode and kwargs['attn_metadata'].is_prompt):
             input_ids = kwargs['input_ids']
             positions = kwargs['positions']
             kwargs = self.model.prepare_attn_masks(
@@ -635,14 +641,16 @@ def compute_input_embeddings_for_mm_optimized(self, **kwargs):
             )
             kwargs['input_ids'] = input_ids
             kwargs['positions'] = positions
-            #input_ids = None
 
         kwargs.update({'inputs_embeds': inputs_embeds})
-        # done compute the visual tokens
+        # done compute the visual tokens and others
         kwargs.pop('pixel_values', None)
+        kwargs.pop("num_crops", None)
+        kwargs.pop("graphed_multimodal_buckets", None)
         return kwargs
 
-    def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
+    def compute_input_embeddings_for_mrope_mm_optimized(
+            self, warmup_mode, **kwargs):
 
         if 'inputs_embeds' in kwargs:
             return kwargs
@@ -681,7 +689,8 @@ def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
             kwargs.pop('image_grid_thw', None)
             return kwargs
         else:
-            return self.compute_input_embeddings_for_mm_optimized(**kwargs)
+            return self.compute_input_embeddings_for_mm_optimized(
+                warmup_mode, **kwargs)
 
     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
@@ -693,9 +702,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = kwargs.pop('virtual_engine')
 
         input_ids = kwargs['input_ids']
-        global_attn_masks = kwargs.get("global_attn_masks") \
+        global_attn_masks = kwargs.pop("global_attn_masks") \
            if kwargs.get("global_attn_masks") else None
-        local_attn_masks = kwargs.get("local_attn_masks") \
+        local_attn_masks = kwargs.pop("local_attn_masks") \
            if kwargs.get("local_attn_masks") else None
 
         kwargs['attn_metadata'] = self._update_metadata(
@@ -1397,12 +1406,8 @@ def get_model(self) -> torch.nn.Module:
             return self.model.model
         return self.model
 
-    def _use_graphs(self, img_args=None):
-        if not img_args:
-            return not self.enforce_eager
-        #TODO: We might need to check both language bucket and multimodal bucket
-        # and return True only it's avialble, or return separately.
-        return (img_args) in self.graphed_multimodal_buckets
+    def _use_graphs(self):
+        return not self.enforce_eager
 
     def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens
@@ -2668,7 +2673,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
 
     def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
                                                     sampling_params,
-                                                    lora_request):
+                                                    lora_request, seq_len):
         assert self.model_is_mrope or self.is_mm_optimized, \
             ("Warmup compatible with Qwen2vl/Gemma3 models")
         if img_args == UNSET_IMG_ARGS:
@@ -2713,7 +2718,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
         }
 
         image_token_id = self.get_model().config.image_token_id
-        prompt_token_ids = [image_token_id] * num_image_tokens
+        prompt_token_ids_image = [image_token_id] * num_image_tokens
+        prompt_token_ids = [0] * (
+            seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         placeholders_by_modality = {
             'image':
@@ -2757,6 +2764,7 @@ def create_dummy_seq_group_metadata(self,
                 img_args=img_args,
                 sampling_params=sampling_params,
                 lora_request=lora_request,
+                seq_len=seq_len,
             )
         else:
             input_len = seq_len
@@ -2868,7 +2876,7 @@ def warmup_scenario(self,
                         align_worker=False,
                         is_dummy_run=False) -> None:
         phase = 'prompt' if is_prompt else 'decode'
-        use_graphs = is_dummy_run or self._use_graphs(img_args)
+        use_graphs = is_dummy_run or self._use_graphs()
 
         scenario_name = ("warmup_"
                         f"{phase}_"
@@ -3665,8 +3673,7 @@ def execute_model(
                 if not warmup_mode:
                     ctx_blocks = seq_len
                 seq_len = 1
-            img_args = self._get_img_args_from_model_input(model_input)
-            use_graphs = self._use_graphs(img_args=img_args)
+            use_graphs = self._use_graphs()
             self._check_config(batch_size, seq_len, ctx_blocks, attn_metadata,
                                warmup_mode)
             lora_mask: torch.Tensor = None
@@ -3832,6 +3839,7 @@ def try_revert_dummy_output_tokens():
                     # hpu graphs, hence turning it to a list
                     execute_model_kwargs = \
                         self.model.compute_input_embeddings_for_mrope_mm_optimized(
+                            warmup_mode,
                            **execute_model_kwargs
                        )
                    if warmup_mode and bypass_model_exec:
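The new seq_len handling in create_dummy_multi_modal_seq_group_metadata pads the dummy warmup prompt so its length matches the warmed-up bucket, keeping the image placeholder tokens as one contiguous block at the end. A tiny sketch of that layout with made-up values (the token id and lengths here are hypothetical, not taken from the model config):

image_token_id = 262144      # hypothetical image placeholder id
num_image_tokens = 4
seq_len = 8

prompt_token_ids_image = [image_token_id] * num_image_tokens
# pad the front with token id 0 so the total length equals seq_len
prompt_token_ids = [0] * (seq_len - len(prompt_token_ids_image)) \
    + prompt_token_ids_image

assert len(prompt_token_ids) == seq_len
print(prompt_token_ids)  # [0, 0, 0, 0, 262144, 262144, 262144, 262144]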
