
Commit bb96123

yeonsily and libinta authored
Merge Libint/intervl_bucket (#1965)
## Essential Elements of an Effective PR Description Checklist
- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results

## Purpose

## Test Plan

## Test Result

<!--- pyml disable-next-line no-emphasis-as-heading -->
---------

Signed-off-by: Libin Tang <[email protected]>
Co-authored-by: Libin Tang <[email protected]>
Co-authored-by: Libin Tang <[email protected]>
1 parent ee517a2 commit bb96123

File tree: 3 files changed, +148 -53 lines changed

vllm/model_executor/models/gemma3_mm.py

Lines changed: 2 additions & 1 deletion
@@ -577,7 +577,8 @@ def _process_image_input(
 
         for i in batch_breakdown:
             end_idx = start_idx + i
-            indices = torch.arange(start_idx, end_idx)
+            indices = torch.arange(start_idx,
+                                   end_idx).to(pixel_values.device)
             batch_sliced_pixel_values = torch.index_select(pixel_values,
                                                            dim=0,
                                                            index=indices)
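The gemma3_mm.py change above is purely a device-placement fix: torch.index_select requires its index tensor to live on the same device as the input it slices, so the indices are now moved to pixel_values.device before indexing. A minimal sketch of the pattern with stand-in tensors (not the actual gemma3 pipeline):

```python
import torch

def slice_batch(pixel_values: torch.Tensor, start_idx: int, end_idx: int) -> torch.Tensor:
    # Build the index tensor, then move it to whatever device holds
    # pixel_values (CPU, CUDA or HPU) before selecting along dim 0.
    indices = torch.arange(start_idx, end_idx).to(pixel_values.device)
    return torch.index_select(pixel_values, dim=0, index=indices)

batch = torch.randn(8, 3, 16, 16)      # stand-in for bucketed pixel values
print(slice_batch(batch, 2, 6).shape)  # torch.Size([4, 3, 16, 16])
```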

vllm/model_executor/models/internvl.py

Lines changed: 76 additions & 15 deletions
@@ -7,6 +7,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Any, Literal, Optional, TypedDict, TypeVar, Union
@@ -35,13 +36,15 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, greedy_plan,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -50,6 +53,9 @@
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
+is_hpu = current_platform.is_hpu()
+is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '0') == '1' if is_hpu else False
+
 
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -1062,6 +1068,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.visual_token_mask = None
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
+        if is_hpu:
+            self.graphed_multimodal_buckets = None
 
     def _patch_quant_config(self, config: PretrainedConfig,
                             quant_config: QuantizationConfig):
@@ -1127,16 +1135,64 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         return x
 
     def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        vit_embeds = self.vision_model(pixel_values=pixel_values)
-        vit_embeds = vit_embeds[:, 1:, :]
-
-        h = w = int(vit_embeds.shape[1]**0.5)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-        vit_embeds = self.pixel_shuffle(vit_embeds,
-                                        scale_factor=self.downsample_ratio)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
-                                        vit_embeds.shape[-1])
-        vit_embeds = self.mlp1(vit_embeds)
+        if is_hpu:
+            if self.vision_buckets.multimodal_buckets:
+                batch_breakdown = greedy_plan(pixel_values.shape[0], \
+                        self.vision_buckets.multimodal_buckets)
+            else:
+                batch_breakdown = [pixel_values.shape[0]]
+
+            start_idx = 0
+            vit_embeds_minibatches = []
+
+            for i in batch_breakdown:
+                end_idx = start_idx + i
+                batch_sliced_pixel_values = \
+                    pixel_values[start_idx:end_idx, ...]
+                if is_lazy:
+                    vit_embeds_minibatch = \
+                        self.vision_model(
+                            pixel_values=batch_sliced_pixel_values,
+                            bypass_hpu_graphs=i
+                            not in self.graphed_multimodal_buckets
+                            and len(self.graphed_multimodal_buckets) > 0)
+                else:
+                    vit_embeds_minibatch = \
+                        self.vision_model(
+                            pixel_values=batch_sliced_pixel_values)
+
+                vit_embeds_minibatch = vit_embeds_minibatch[:, 1:, :]
+
+                h = w = int(vit_embeds_minibatch.shape[1]**0.5)
+                vit_embeds_minibatch = vit_embeds_minibatch.reshape(
+                    vit_embeds_minibatch.shape[0], h, w, -1)
+                vit_embeds_minibatch = self.pixel_shuffle(
+                    vit_embeds_minibatch, scale_factor=self.downsample_ratio)
+                vit_embeds_minibatch = vit_embeds_minibatch.reshape(
+                    vit_embeds_minibatch.shape[0], -1,
+                    vit_embeds_minibatch.shape[-1])
+
+                if is_lazy:
+                    vit_embeds_minibatches += [
+                        self.mlp1(vit_embeds_minibatch,
+                                  bypass_hpu_graphs=i
+                                  not in self.graphed_multimodal_buckets
+                                  and len(self.graphed_multimodal_buckets) > 0)
+                    ]
+                else:
+                    vit_embeds_minibatches += [self.mlp1(vit_embeds_minibatch)]
+                start_idx = end_idx
+            vit_embeds = torch.cat(vit_embeds_minibatches, dim=0)
+        else:
+            vit_embeds = self.vision_model(pixel_values=pixel_values)
+            vit_embeds = vit_embeds[:, 1:, :]
+            h = w = int(vit_embeds.shape[1]**0.5)
+            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+            vit_embeds = self.pixel_shuffle(vit_embeds,
+                                            scale_factor=self.downsample_ratio)
+            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
+                                            vit_embeds.shape[-1])
+            vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds
 
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
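The HPU branch above splits the incoming patch batch into bucket-sized minibatches so that every vision_model/mlp1 call runs at a shape that can be captured as an HPU graph, and in lazy mode bypass_hpu_graphs skips graph replay for shapes that were never warmed up. The real greedy_plan helper is imported from vllm/model_executor/models/utils.py and is not shown in this diff; the sketch below is only an assumed illustration of how such a plan could decompose a batch:

```python
def greedy_plan_sketch(batch_size: int, buckets: list[int]) -> list[int]:
    """Assumed behaviour: cover batch_size with bucket sizes, largest first,
    so each minibatch shape matches a pre-compiled bucket."""
    plan: list[int] = []
    remaining = batch_size
    for bucket in sorted(buckets, reverse=True):
        while remaining >= bucket:
            plan.append(bucket)
            remaining -= bucket
    if remaining > 0:
        # Round the tail up to the smallest bucket; the caller slices
        # pixel_values[start:end], so overshooting the batch is harmless.
        plan.append(min(buckets))
    return plan

# e.g. 11 image patches with buckets [1, 2, 4, 8] -> [8, 2, 1]
print(greedy_plan_sketch(11, [1, 2, 4, 8]))
```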
@@ -1180,8 +1236,11 @@ def _parse_and_validate_image_input(
 
         image_token_id = kwargs["image_token_id"]
         assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
-
+        if is_hpu:
+            self.img_context_token_id = image_token_id.flatten()
+        else:
+            self.img_context_token_id = image_token_id.flatten().unique().item(
+            )
         if pixel_values_flat is not None:
             if not isinstance(pixel_values_flat, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
@@ -1306,7 +1365,9 @@ def get_language_model(self) -> torch.nn.Module:
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
-
+        if is_hpu:
+            self.graphed_multimodal_buckets = kwargs.pop(
+                'graphed_multimodal_buckets', [])
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
             return None

vllm/worker/hpu_model_runner.py

Lines changed: 70 additions & 37 deletions
@@ -112,16 +112,22 @@ class VisionBuckets:
     This class is used to bucket image tokens
     '''
 
-    def __init__(self, is_batch_based):
-        self.is_batch_based = is_batch_based
+    def __init__(self, model):
+        self.is_batch_based = True
         envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")
         if envvar == 'None':
             self.multimodal_buckets = None
         else:
             if envvar == "":
-                if is_batch_based:
+                if 'InternVLChatModel' in str(type(model)):
+                    multimodal_buckets = list(
+                        range(model.config.min_dynamic_patch,
+                              model.config.max_dynamic_patch +
+                              2))  #As use_thumbnail is true
+                elif 'Gemma3ForConditionalGeneration' in str(type(model)):
                     multimodal_buckets = [1, 2, 4, 8]  # batch sizes for gemma3
                 else:
+                    self.is_batch_based = False
                     multimodal_buckets = [
                         1600, 3136, 4096, 6400, 7744, 9216, 12544
                     ]
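VisionBuckets now inspects the model instance instead of a boolean flag: InternVL gets patch-count buckets derived from its dynamic-patch config (the + 2 in range() covers the extra thumbnail tile), Gemma3 keeps its small batch-size buckets, and every other case falls back to the existing token-count buckets with is_batch_based = False. A quick illustration of the InternVL derivation, assuming the common defaults min_dynamic_patch=1 and max_dynamic_patch=12:

```python
# Assumed example values; the real ones come from model.config.
min_dynamic_patch, max_dynamic_patch = 1, 12

# use_thumbnail adds one tile, so the largest patch count is max + 1,
# and range() needs max + 2 to include it.
buckets = list(range(min_dynamic_patch, max_dynamic_patch + 2))
print(buckets)  # [1, 2, 3, ..., 13]
```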
@@ -159,9 +165,11 @@ def __call__(cls, *args, **kwargs):
 
 
 def is_mm_optimized(model):
-    return 'Gemma3ForConditionalGeneration' in str(type(model.model)) \
-        if hasattr(model, 'model') else \
-        'Gemma3ForConditionalGeneration' in str(type(model))
+    mm_models = ['Gemma3ForConditionalGeneration', 'InternVLChatModel']
+
+    return any(m in str(type(model.model)) for m in mm_models) \
+        if hasattr(model, 'model') \
+        else any(m in str(type(model)) for m in mm_models)
 
 
 def pad_flat_tensor(tensor, desired_size):
@@ -345,6 +353,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         model_config = getattr(self.model, "config", None)
 
         self.model_is_mrope = uses_mrope(model_config)
+
         self.is_mm_optimized = is_mm_optimized(self.model)
         text_config = vllm_config.model_config.hf_config.get_text_config()
         self.interleaved_sliding_window = getattr(
@@ -379,6 +388,12 @@ def __init__(self, model, vllm_config, is_causal, sampler):
                 htorch.hpu.wrap_in_hpu_graph( \
                     self.model.multi_modal_projector, \
                     disable_tensor_cache=True)
+            if hasattr(self.model, 'vision_model'):
+                self.model.vision_model = htorch.hpu.wrap_in_hpu_graph(
+                    self.model.vision_model, disable_tensor_cache=True)
+            if hasattr(self.model, 'mlp1'):
+                self.model.mlp1 = htorch.hpu.wrap_in_hpu_graph(
+                    self.model.mlp1, disable_tensor_cache=True)
 
         self._rotary_embed_module = self._get_rotary_embedding_module(
             self.model)
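For InternVL the runner now also wraps the vision tower (vision_model) and the projector MLP (mlp1) in HPU graphs, mirroring what was already done for Gemma3's multi_modal_projector. A hedged sketch of the wrapping pattern; htorch stands for the habana_frameworks.torch module used by the runner and only works on Gaudi hardware, so it is passed in rather than imported here:

```python
import torch.nn as nn

def wrap_vision_submodules(model: nn.Module, htorch) -> None:
    # Wrap only the submodules this model actually has; disable_tensor_cache
    # matches the flag already used for the Gemma3 projector.
    for name in ('vision_model', 'mlp1'):
        if hasattr(model, name):
            wrapped = htorch.hpu.wrap_in_hpu_graph(
                getattr(model, name), disable_tensor_cache=True)
            setattr(model, name, wrapped)
```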
@@ -624,26 +639,30 @@ def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
         inputs_embeds = self.model.get_input_embeddings(
             input_ids, vision_embeddings)
-
         # TODO: In warmup, we need to warmup the model with dummy image data for
         # multimodal model for prompt, here instead of generating a dummy image,
         # we are just generating attn_mask for the images and pass with
         # attn_metadata, so we can reuse HPU graph without running
         # the whole vision tower.
         if vision_embeddings is not None or (
                 warmup_mode and kwargs['attn_metadata'].is_prompt):
-            input_ids = kwargs['input_ids']
-            positions = kwargs['positions']
-            kwargs = self.model.prepare_attn_masks(
-                mask_dtype=self.dtype,
-                **kwargs,
-            )
-            kwargs['input_ids'] = input_ids
-            kwargs['positions'] = positions
+            if hasattr(self.model, 'prepare_attn_masks'):
+                input_ids = kwargs['input_ids']
+                positions = kwargs['positions']
+                kwargs = self.model.prepare_attn_masks(
+                    mask_dtype=self.dtype,
+                    **kwargs,
+                )
+                kwargs['input_ids'] = input_ids
+                kwargs['positions'] = positions
+                # done compute the visual tokens
+                kwargs.pop('pixel_values', None)
+            else:
+                kwargs.pop('pixel_values_flat', None)
+                kwargs.pop("image_num_patches", None)
+                kwargs.pop("image_token_id", None)
 
         kwargs.update({'inputs_embeds': inputs_embeds})
-        # done compute the visual tokens and others
-        kwargs.pop('pixel_values', None)
         kwargs.pop("num_crops", None)
         kwargs.pop("graphed_multimodal_buckets", None)
         return kwargs
@@ -699,7 +718,6 @@ def forward(self, *args, **kwargs):
         virtual_engine = 0
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
-
         input_ids = kwargs['input_ids']
         global_attn_masks = kwargs.pop("global_attn_masks") \
             if kwargs.get("global_attn_masks") else None
@@ -1080,6 +1098,8 @@ def __init__(
                                          and not self.lora_config)
         self.use_delayed_sampling = get_config(
         ).use_delayed_sampling and can_use_delayed_sampling
+        self.mm_tokens_per_image = 1
+        self.image_token_id = 0
 
     def _set_gc_threshold(self) -> None:
         """
@@ -1497,10 +1517,16 @@ def move_to_device(self, tensor):
                                                        non_blocking=True)
 
     def add_vision_buckets_to_mrope_mm_optimized(self):
-        model = self.get_model()
-        self.is_mm_optimized = is_mm_optimized(model)
+        self.is_mm_optimized = is_mm_optimized(self.model)
         if self.model_is_mrope or self.is_mm_optimized:
-            model.vision_buckets = VisionBuckets(self.is_mm_optimized)
+            if hasattr(self.model.model.config, 'mm_tokens_per_image'):
+                self.mm_tokens_per_image = \
+                    self.model.model.config.mm_tokens_per_image
+                self.image_token_id = self.model.model.config.image_token_id
+            elif 'InternVLChatModel' in str(type(self.model.model)):
+                self.image_token_id = 151667
+                self.mm_tokens_per_image = self.model.model.num_image_token
+            self.model.model.vision_buckets = VisionBuckets(self.model.model)
 
     def _prepare_prompt(
         self,
@@ -1631,7 +1657,6 @@ def _prepare_prompt(
                     for idx in range(3):
                         seq_data_mrope_positions[idx] \
                             .extend(mrope_positions[idx])
-
             multi_modal_kwargs_list.append(mm_kwargs)
 
             for modality, placeholder_map in placeholder_maps.items():
@@ -2709,17 +2734,28 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
         else:
             s = self.model.model.config.vision_config.image_size
             pixel_values = torch.randn([img_args, 3, s, s])
-        num_image_tokens = self.model.model.config.mm_tokens_per_image \
-            * img_args
-        multi_modal_data = {
-            "pixel_values": pixel_values,
-            "num_crops": torch.zeros([img_args], dtype=torch.int32)
-        }
 
-        image_token_id = self.get_model().config.image_token_id
-        prompt_token_ids_image = [image_token_id] * num_image_tokens
+        if 'Gemma3ForConditionalGeneration' in str(type(self.model.model)):
+            multi_modal_data = {
+                "pixel_values": pixel_values,
+                "num_crops": torch.zeros([img_args], dtype=torch.int32),
+            }
+        elif 'InternVLChatModel' in str(type(self.model.model)):
+            multi_modal_data = {
+                "pixel_values_flat":
+                pixel_values.to(torch.bfloat16),
+                "image_num_patches":
+                torch.tensor([pixel_values.shape[0]], dtype=torch.int32),
+                "image_token_id":
+                torch.tensor([self.image_token_id], dtype=torch.int64),
+            }
+        else:
+            logger.warning("No support for other models yet")
+        num_image_tokens = self.mm_tokens_per_image * img_args
+        prompt_token_ids_image = [self.image_token_id] * num_image_tokens
         prompt_token_ids = [0] * (
             seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
+
         prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
         placeholders_by_modality = {
             'image':
@@ -3188,9 +3224,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                 if graphs:
                     self.graphed_buckets.add(cfg)
                 if self.is_mm_run():
-                    img_args = (int(seq_len) //
-                                self.model.model.config.mm_tokens_per_image
-                                if self.is_mm_optimized else int(seq_len))
+                    img_args = int(seq_len) // self.mm_tokens_per_image
                     self.warmup_scenario(
                         int(bs),
                         int(seq_len),
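Warmup now sizes dummy image inputs from the cached self.mm_tokens_per_image rather than the Gemma3-only config field, so the same arithmetic serves InternVL. The relation between the warmup sequence length, the number of dummy patches, and the dummy prompt layout is plain integer division; a worked example under assumed values:

```python
# Assumed example values: 256 tokens per image tile, 1024-token warmup bucket.
mm_tokens_per_image = 256
seq_len = 1024
image_token_id = 151667  # InternVL image token id hard-coded in this diff

img_args = seq_len // mm_tokens_per_image          # dummy patches -> 4
num_image_tokens = mm_tokens_per_image * img_args  # image tokens  -> 1024

# The dummy prompt is left-padded with token 0 and ends with the image tokens,
# as in create_dummy_multi_modal_seq_group_metadata above.
prompt_token_ids = [0] * (seq_len - num_image_tokens) + \
    [image_token_id] * num_image_tokens
assert len(prompt_token_ids) == seq_len
```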
@@ -3539,7 +3573,7 @@ def _get_seq_ids(self, model_input):
     def _get_img_args_from_model_input(self, model_input):
         if (not self.model_is_mrope and not self.is_mm_optimized) or \
             not model_input.multi_modal_kwargs or \
-                'pixel_values' not in model_input.multi_modal_kwargs:
+                ('pixel_values') not in model_input.multi_modal_kwargs:
             return None
         if self.model_is_mrope:
             pixel_values_list = model_input.multi_modal_kwargs['pixel_values']
@@ -3816,18 +3850,17 @@ def try_revert_dummy_output_tokens():
                     'real_seq_len': model_input.seq_lens,
                     'real_batch_size': real_batch_size
                 }
-
                 #Need to set the window_slide mask at this point to decide
                 if is_prompt:
                     attn_metadata = self.model._update_use_window_sdpa(
                         execute_model_kwargs['attn_metadata'], seq_len,
                         bool(model_input.multi_modal_kwargs and \
-                            'pixel_values' in model_input.multi_modal_kwargs))
+                            ('pixel_values') in model_input.multi_modal_kwargs))
                     execute_model_kwargs['attn_metadata'] = attn_metadata
 
                 if not bypass_model_exec:
                     if self.model_is_mrope or self.is_mm_optimized:
-                        if 'pixel_values' in execute_model_kwargs and \
+                        if ('pixel_values') in execute_model_kwargs and \
                             self.is_mm_optimized:
                             if warmup_mode and not is_pt_profiler_run:
                                 bypass_model_exec = True
