Commit 9fe5f47

fix param

1 parent: 3bfdb15

File tree

1 file changed: +111 −1 lines changed

vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 111 additions & 1 deletion
@@ -65,7 +65,9 @@
     Qwen3VLProcessingInfo = object
     Qwen3VLMoeForConditionalGeneration = object
     Qwen3VLMoeProcessingInfo = object
-from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
+from vllm.model_executor.models.utils import (WeightsMapper,
+                                              _merge_multimodal_embeddings,
+                                              maybe_prefix)
 from vllm.multimodal import MULTIMODAL_REGISTRY

 from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
@@ -669,3 +671,111 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
         )
+    def _get_text_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
+        *,
+        is_multimodal: Optional[torch.Tensor],
+        handle_oov_mm_token: bool,
+    ) -> torch.Tensor:
+        if handle_oov_mm_token and is_multimodal is not None:
+            is_text = ~is_multimodal
+            text_embeds = get_input_embeddings(input_ids[is_text])
+            return torch.empty(
+                (input_ids.shape[0], text_embeds.shape[1]),
+                dtype=text_embeds.dtype,
+                device=text_embeds.device,
+            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+        return get_input_embeddings(input_ids)
+
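Note for reviewers: the OOV-safe branch embeds only the text tokens and scatters them into an uninitialized buffer, leaving the multimodal rows to be overwritten later. A minimal sketch of that trick (the toy vocabulary, mask, and sizes below are assumptions, not values from the commit):

```python
import torch

embed = torch.nn.Embedding(10, 4)                 # toy vocab of 10, hidden size 4
input_ids = torch.tensor([1, 2, 99, 3])           # 99 is out-of-vocab (image slot)
is_multimodal = torch.tensor([False, False, True, False])

is_text = ~is_multimodal
text_embeds = embed(input_ids[is_text])           # only in-vocab ids hit the table
buf = torch.empty((input_ids.shape[0], text_embeds.shape[1]),
                  dtype=text_embeds.dtype, device=text_embeds.device)
buf.masked_scatter_(is_text.unsqueeze(-1), text_embeds)
# Rows where is_multimodal is True stay uninitialized here; they are filled
# by _merge_multimodal_embeddings afterwards.
```

Skipping the OOV rows is what makes placeholder IDs above the language model's vocabulary safe, at the cost of the extra buffer the docstring mentions.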
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        """
+        Apply token embeddings to `input_ids`.
+        If `multimodal_embeddings` is passed, scatter them into
+        `input_ids` according to the mask `is_multimodal`.
+        In case the multi-modal token IDs exceed the vocabulary size of
+        the language model, you can set `handle_oov_mm_token=True`
+        to avoid calling the language model's `get_input_embeddings` method
+        on those tokens. Note however that doing so increases memory usage
+        as an additional buffer is needed to hold the input embeddings.
+        """
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.get_language_model().get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+        if self.use_deepstack:
+            (
+                deepstack_input_embeds,
+                multimodal_embeddings,
+            ) = self._compute_deepstack_embeds(
+                inputs_embeds=inputs_embeds,
+                multimodal_embeddings=multimodal_embeddings,
+                is_multimodal=is_multimodal,
+            )
+        else:
+            deepstack_input_embeds = None
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            is_multimodal=is_multimodal,
+            multimodal_embeddings=multimodal_embeddings,
+        )
+        if deepstack_input_embeds is not None:
+            self._set_deepstack_input_embeds(deepstack_input_embeds)
+        return inputs_embeds
+
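For model-runner authors: after this change, multimodal embeddings are merged via an explicit `is_multimodal` mask rather than inferred from placeholder token IDs; callers that omit the mask hit the ValueError above, which points to vllm PR #16229. In plain torch, the merge amounts to the following (toy shapes; the real `_merge_multimodal_embeddings` helper in vllm.model_executor.models.utils also validates the inputs):

```python
import torch

hidden = 4
inputs_embeds = torch.zeros(5, hidden)                  # text rows already embedded
is_multimodal = torch.tensor([False, False, True, True, False])
mm_embeds = (torch.ones(2, hidden),)                    # one image, two tokens

# Scatter the flattened multimodal rows into the masked positions.
inputs_embeds[is_multimodal] = torch.cat(mm_embeds, dim=0)
```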
+    def _compute_deepstack_embeds(
+        self,
+        inputs_embeds: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings,
+        is_multimodal: torch.Tensor,
+    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
+
+        visual_lens = [len(x) for x in multimodal_embeddings]
+        multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
+
+        total_dim = multimodal_embeddings_cat.shape[-1]
+        assert total_dim == self.visual_dim + self.multiscale_dim, \
+            f"Total dimension mismatch: input {total_dim}, expected {self.visual_dim + self.multiscale_dim}"
+        multimodal_embeddings_main = multimodal_embeddings_cat[
+            ..., :self.visual_dim]
+        multimodal_embeddings_multiscale = multimodal_embeddings_cat[
+            ..., self.visual_dim:]
+
+        multimodal_embeddings = torch.split(multimodal_embeddings_main,
+                                            visual_lens,
+                                            dim=0)
+        multimodal_embeddings_multiscale = torch.split(
+            multimodal_embeddings_multiscale, visual_lens, dim=0)
+
+        deepstack_input_embeds = inputs_embeds.new_zeros(
+            inputs_embeds.size(0),
+            self.deepstack_num_level * inputs_embeds.size(1))
+
+        deepstack_input_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=deepstack_input_embeds,
+            multimodal_embeddings=multimodal_embeddings_multiscale,
+            is_multimodal=is_multimodal,
+        )
+        deepstack_input_embeds = deepstack_input_embeds.view(
+            inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
+        deepstack_input_embeds = deepstack_input_embeds.permute(
+            1, 0, 2).contiguous()
+
+        return deepstack_input_embeds, multimodal_embeddings
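The deepstack path assumes each visual token carries `visual_dim + multiscale_dim` features: the first slice feeds the language model as usual, and the remainder is re-packed into one plane per deepstack level. A toy walk-through of that arithmetic (dimensions invented; they satisfy `multiscale_dim == deepstack_num_level * visual_dim`, which the buffer/view steps above rely on):

```python
import torch

visual_dim, num_level = 4, 2
multiscale_dim = num_level * visual_dim
mm = torch.randn(3, visual_dim + multiscale_dim)   # one image, 3 visual tokens

main = mm[..., :visual_dim]                        # merged into inputs_embeds
multiscale = mm[..., visual_dim:]                  # extra per-level features

is_multimodal = torch.tensor([False, True, True, True, False])
buf = torch.zeros(5, num_level * visual_dim)
buf[is_multimodal] = multiscale                    # same merge as above
deepstack = buf.view(5, num_level, visual_dim).permute(1, 0, 2).contiguous()
# deepstack: (num_level, seq_len, visual_dim), one tensor per level, consumed
# later via _set_deepstack_input_embeds.
```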
