Commit 30c226e

Modify merge_multimodal_embeddings to static (#1969)
1 parent 716f3fc commit 30c226e
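
Adds merge_multimodal_embeddings_static, an HPU-only helper that writes vision embeddings into the text embeddings at precomputed flat token positions via index_copy_, and wires it up through a new InternVL get_input_embeddings_hpu method and the HPU model runner, which now builds an image_index tensor during prompt preparation.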

File tree

3 files changed: +55 -3 lines changed

vllm/model_executor/models/internvl.py
vllm/model_executor/models/utils.py
vllm/worker/hpu_model_runner.py


vllm/model_executor/models/internvl.py

Lines changed: 17 additions & 1 deletion

@@ -44,7 +44,8 @@
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, flatten_bn, greedy_plan,
                     init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    merge_multimodal_embeddings,
+                    merge_multimodal_embeddings_static)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -1390,6 +1391,21 @@ def get_multimodal_embeddings(
 
         return multimodal_embeddings
 
+    def get_input_embeddings_hpu(
+        self,
+        input_ids: torch.Tensor,
+        image_index_tensor: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings_static(
+                image_index_tensor,
+                inputs_embeds,
+                multimodal_embeddings,
+            )
+        return inputs_embeds
+
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
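
A rough sketch of how the new method is intended to be called (mirroring the runner changes further down); model, pixel_values, and IMG_CONTEXT_ID are illustrative stand-ins, not names taken from this commit:

    # Hypothetical caller: precompute the flat positions of image tokens,
    # then let the model copy the vision embeddings into those slots.
    image_index = (input_ids == IMG_CONTEXT_ID).flatten().nonzero().squeeze(-1)
    vision_embeds = model.get_multimodal_embeddings(pixel_values=pixel_values)
    inputs_embeds = model.get_input_embeddings_hpu(input_ids, image_index,
                                                   vision_embeds)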

vllm/model_executor/models/utils.py

Lines changed: 16 additions & 0 deletions

@@ -433,6 +433,22 @@ def merge_multimodal_embeddings_from_map(
     return inputs_embeds
 
 
+def merge_multimodal_embeddings_static(
+    is_multimodal_index: torch.Tensor,
+    inputs_embeds: torch.Tensor,
+    multimodal_embeddings: NestedTensors,
+) -> torch.Tensor:
+    assert current_platform.is_hpu(), ("Support HPU only")
+    flattened = _flatten_embeddings(multimodal_embeddings)
+
+    inputs_embeds_s = inputs_embeds.shape
+    inputs_embeds = inputs_embeds.view(inputs_embeds_s[0] * inputs_embeds_s[1],
+                                       inputs_embeds_s[2])
+    inputs_embeds = inputs_embeds.index_copy_(0, is_multimodal_index,
+                                              flattened).view(inputs_embeds_s)
+    return inputs_embeds
+
+
 def _merge_multimodal_embeddings(
     inputs_embeds: torch.Tensor,
     is_multimodal: torch.Tensor,
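
A minimal CPU toy of the same index_copy_ technique (the committed helper asserts HPU; all shapes and values here are illustrative):

    import torch

    batch, seq, hidden = 2, 4, 8
    inputs_embeds = torch.zeros(batch, seq, hidden)
    image_embeds = torch.ones(3, hidden)      # three image tokens in total
    image_index = torch.tensor([1, 2, 5])     # flat (batch * seq) positions

    s = inputs_embeds.shape
    flat = inputs_embeds.view(s[0] * s[1], s[2])
    flat.index_copy_(0, image_index, image_embeds)  # in-place row copy
    merged = flat.view(s)                     # rows 1, 2 and 5 now hold ones

Unlike a boolean-mask masked_scatter_, both the index tensor and every intermediate shape are fixed up front, which is presumably what "static" in the name refers to.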

vllm/worker/hpu_model_runner.py

Lines changed: 22 additions & 2 deletions

@@ -637,8 +637,14 @@ def _update_metadata(self,
     def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
         input_ids = kwargs['input_ids']
         vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
-        inputs_embeds = self.model.get_input_embeddings(
-            input_ids, vision_embeddings)
+        if 'image_index' in kwargs:
+            inputs_embeds = self.model.get_input_embeddings_hpu(
+                input_ids, kwargs['image_index'], vision_embeddings)
+            kwargs.pop("image_index", None)
+        else:
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids, vision_embeddings)
+
         # TODO: In warmup, we need to warmup the model with dummy image data for
         # multimodal model for prompt, here instead of generating a dummy image,
         # we are just generating attn_mask for the images and pass with
@@ -1772,6 +1778,7 @@ def _prepare_prompt(
             pad=0,
             dtype=torch.long,
             flat=self.use_merged_prefill)
+        image_index_tensor = None
         if self.model_is_mrope:
             input_positions = \
                 make_mrope_positions_tensor_with_pad(input_positions=input_positions,
@@ -1785,6 +1792,11 @@ def _prepare_prompt(
                 dtype=torch.long,
                 flat=self.use_merged_prefill)
 
+        if seq_group_metadata.multi_modal_data and self.is_mm_optimized and \
+                'InternVLChatModel' in str(type(self.model.model)):
+            is_image_flatten = (
+                input_tokens_tensor == self.image_token_id).flatten()
+            image_index_tensor = is_image_flatten.nonzero().squeeze(-1)
         slot_mapping = make_cpu_tensor(slot_mapping,
                                        max_len=max_prompt_len,
                                        pad=_PAD_SLOT_ID,
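
For concreteness, the index computation above on a toy padded batch (token values are illustrative; 9 stands in for self.image_token_id):

    import torch

    image_token_id = 9
    input_tokens_tensor = torch.tensor([[1, 9, 9, 0],
                                        [1, 2, 9, 0]])
    is_image_flatten = (input_tokens_tensor == image_token_id).flatten()
    image_index_tensor = is_image_flatten.nonzero().squeeze(-1)
    # tensor([1, 2, 6]): row-major offsets into the flattened (batch * seq) axis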
@@ -1872,6 +1884,8 @@ def _prepare_prompt(
             input_positions=input_positions,
         )
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+        if image_index_tensor is not None:
+            multi_modal_kwargs['image_index'] = image_index_tensor
         multi_modal_kwargs = MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                         device=self.device)
 
@@ -3872,6 +3886,12 @@ def try_revert_dummy_output_tokens():
                 ('pixel_values') in model_input.multi_modal_kwargs))
         execute_model_kwargs['attn_metadata'] = attn_metadata
 
+        if 'image_index' in model_input.multi_modal_kwargs:
+            execute_model_kwargs[
+                'image_index'] = model_input.multi_modal_kwargs[
+                    'image_index']
+            model_input.multi_modal_kwargs.pop('image_index', None)
+
         if not bypass_model_exec:
             if self.model_is_mrope or self.is_mm_optimized:
                 if ('pixel_values') in execute_model_kwargs and \
