65 | 65 | Qwen3VLProcessingInfo = object |
66 | 66 | Qwen3VLMoeForConditionalGeneration = object |
67 | 67 | Qwen3VLMoeProcessingInfo = object |
68 | | -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix |
| 68 | +from vllm.model_executor.models.utils import (WeightsMapper, |
| 69 | + _merge_multimodal_embeddings, |
| 70 | + maybe_prefix) |
69 | 71 | from vllm.multimodal import MULTIMODAL_REGISTRY |
70 | 72 |
71 | 73 | from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding |
@@ -669,3 +671,112 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
669 | 671 | prefix=maybe_prefix(prefix, "visual"), |
670 | 672 | use_data_parallel=self.use_data_parallel, |
671 | 673 | ) |
| 674 | + |
| 675 | + def _get_text_embeddings( |
| 676 | + self, |
| 677 | + input_ids: torch.Tensor, |
| 678 | + get_input_embeddings: Callable[[torch.Tensor], torch.Tensor], |
| 679 | + *, |
| 680 | + is_multimodal: Optional[torch.Tensor], |
| 681 | + handle_oov_mm_token: bool, |
| 682 | + ) -> torch.Tensor: |
| 683 | + if handle_oov_mm_token and is_multimodal is not None: |
| 684 | + is_text = ~is_multimodal |
| 685 | + text_embeds = get_input_embeddings(input_ids[is_text]) |
| 686 | + return torch.empty( |
| 687 | + (input_ids.shape[0], text_embeds.shape[1]), |
| 688 | + dtype=text_embeds.dtype, |
| 689 | + device=text_embeds.device, |
| 690 | + ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) |
| 691 | + return get_input_embeddings(input_ids) |
| 692 | + |
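For context on the OOV branch in `_get_text_embeddings`: when `handle_oov_mm_token` is set, only the text positions are looked up in the embedding table and the results are scattered back into a full-length buffer, leaving the multimodal slots to be overwritten later. A minimal standalone sketch of that pattern (the toy vocab size, placeholder id, and plain `nn.Embedding` are illustrative assumptions, not the real model):

```python
import torch
import torch.nn as nn

# Toy setup: a text vocab of 10; any id >= 10 is a multimodal placeholder (OOV).
embed = nn.Embedding(10, 4)
input_ids = torch.tensor([1, 2, 10_000, 10_000, 3])
is_multimodal = input_ids >= 10

# Embed only the in-vocab text tokens, then scatter them into a full-size buffer.
is_text = ~is_multimodal
text_embeds = embed(input_ids[is_text])
inputs_embeds = torch.empty(
    (input_ids.shape[0], text_embeds.shape[1]),
    dtype=text_embeds.dtype,
    device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze(-1), text_embeds)

# Rows 2 and 3 are still uninitialized here; the later merge step is expected
# to overwrite them with the vision embeddings.
print(inputs_embeds.shape)  # torch.Size([5, 4])
```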
| 693 | + def get_input_embeddings( |
| 694 | + self, |
| 695 | + input_ids: torch.Tensor, |
| 696 | + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, |
| 697 | + *, |
| 698 | + is_multimodal: Optional[torch.Tensor] = None, |
| 699 | + handle_oov_mm_token: bool = False, |
| 700 | + ) -> torch.Tensor: |
| 701 | + """ |
| 702 | + Apply token embeddings to `input_ids`. |
| 703 | + If `multimodal_embeddings` is passed, scatter them into |
| 704 | + `input_ids` according to the mask `is_multimodal`. |
| 705 | + In case the multi-modal token IDs exceed the vocabulary size of |
| 706 | + the language model, you can set `handle_oov_mm_token=False` |
| 707 | + to avoid calling the language model's `get_input_embeddings` method |
| 708 | + on those tokens. Note however that doing so increases memory usage |
| 709 | + as an additional buffer is needed to hold the input embeddings. |
| 710 | + """ |
| 711 | + inputs_embeds = self._get_text_embeddings( |
| 712 | + input_ids, |
| 713 | + self.get_language_model().get_input_embeddings, |
| 714 | + is_multimodal=is_multimodal, |
| 715 | + handle_oov_mm_token=handle_oov_mm_token, |
| 716 | + ) |
| 717 | + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: |
| 718 | + return inputs_embeds |
| 719 | + if is_multimodal is None: |
| 720 | + raise ValueError( |
| 721 | + "`get_input_embeddings` now requires `is_multimodal` arg, " |
| 722 | + "please update your model runner according to " |
| 723 | + "https://github.com/vllm-project/vllm/pull/16229.") |
| 724 | + if self.use_deepstack: |
| 725 | + ( |
| 726 | + deepstack_input_embeds, |
| 727 | + multimodal_embeddings, |
| 728 | + ) = self._compute_deepstack_embeds( |
| 729 | + inputs_embeds=inputs_embeds, |
| 730 | + multimodal_embeddings=multimodal_embeddings, |
| 731 | + is_multimodal=is_multimodal, |
| 732 | + ) |
| 733 | + else: |
| 734 | + deepstack_input_embeds = None |
| 735 | + inputs_embeds = _merge_multimodal_embeddings( |
| 736 | + inputs_embeds=inputs_embeds, |
| 737 | + is_multimodal=is_multimodal, |
| 738 | + multimodal_embeddings=multimodal_embeddings, |
| 739 | + ) |
| 740 | + if deepstack_input_embeds is not None: |
| 741 | + self._set_deepstack_input_embeds(deepstack_input_embeds) |
| 742 | + return inputs_embeds |
| 743 | + |
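As described in the docstring above, `get_input_embeddings` scatters the concatenated multimodal embeddings into the positions flagged by `is_multimodal`. A behavioral stand-in for that merge with toy shapes (the real helper is `_merge_multimodal_embeddings` from `vllm.model_executor.models.utils`; the boolean-index assignment below only mimics its effect):

```python
import torch

hidden = 4
inputs_embeds = torch.zeros(5, hidden)                  # text embeddings for 5 tokens
is_multimodal = torch.tensor([False, False, True, True, False])
mm_embeds = (torch.ones(2, hidden),)                    # one image contributing 2 tokens

# Flatten the per-item embeddings and write them into the flagged positions.
flat_mm = torch.cat(mm_embeds, dim=0)
assert int(is_multimodal.sum()) == flat_mm.shape[0], "mm token count mismatch"
inputs_embeds[is_multimodal] = flat_mm

print(inputs_embeds[2])  # tensor([1., 1., 1., 1.])
```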
| 744 | + def _compute_deepstack_embeds( |
| 745 | + self, |
| 746 | + inputs_embeds: torch.Tensor, |
| 747 | + multimodal_embeddings: MultiModalEmbeddings, |
| 748 | + is_multimodal: torch.Tensor, |
| 749 | + ) -> tuple[torch.Tensor, MultiModalEmbeddings]: |
| 750 | + |
| 751 | + visual_lens = [len(x) for x in multimodal_embeddings] |
| 752 | + multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) |
| 753 | + |
| 754 | + total_dim = multimodal_embeddings_cat.shape[-1] |
| 755 | + assert total_dim == self.visual_dim + self.multiscale_dim, \ |
| 756 | + f"Total dimension mismatch: input {total_dim}, expected {self.visual_dim + self.multiscale_dim}" |
| 757 | + multimodal_embeddings_main = multimodal_embeddings_cat[ |
| 758 | + ..., :self.visual_dim] |
| 759 | + multimodal_embeddings_multiscale = multimodal_embeddings_cat[ |
| 760 | + ..., self.visual_dim:] |
| 761 | + |
| 762 | + multimodal_embeddings = torch.split(multimodal_embeddings_main, |
| 763 | + visual_lens, |
| 764 | + dim=0) |
| 765 | + multimodal_embeddings_multiscale = torch.split( |
| 766 | + multimodal_embeddings_multiscale, visual_lens, dim=0) |
| 767 | + |
| 768 | + deepstack_input_embeds = inputs_embeds.new_zeros( |
| 769 | + inputs_embeds.size(0), |
| 770 | + self.deepstack_num_level * inputs_embeds.size(1)) |
| 771 | + |
| 772 | + deepstack_input_embeds = _merge_multimodal_embeddings( |
| 773 | + inputs_embeds=deepstack_input_embeds, |
| 774 | + multimodal_embeddings=multimodal_embeddings_multiscale, |
| 775 | + is_multimodal=is_multimodal, |
| 776 | + ) |
| 777 | + deepstack_input_embeds = deepstack_input_embeds.view( |
| 778 | + inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) |
| 779 | + deepstack_input_embeds = deepstack_input_embeds.permute( |
| 780 | + 1, 0, 2).contiguous() |
| 781 | + |
| 782 | + return deepstack_input_embeds, multimodal_embeddings |
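To make the deepstack bookkeeping above easier to follow: each vision token carries `visual_dim` main features plus `deepstack_num_level * visual_dim` multiscale features concatenated on the last dimension; the multiscale part is scattered into a per-token buffer and then reshaped to one tensor per level. A standalone walk-through with made-up dimensions (all sizes and the direct boolean-index write are assumptions for illustration):

```python
import torch

visual_dim, deepstack_num_level, seq_len = 4, 3, 6
multiscale_dim = deepstack_num_level * visual_dim        # extra features per vision token

mm_embeds = torch.randn(2, visual_dim + multiscale_dim)  # 2 vision tokens
is_multimodal = torch.tensor([False, True, True, False, False, False])

main = mm_embeds[..., :visual_dim]        # merged into inputs_embeds as usual
multiscale = mm_embeds[..., visual_dim:]  # routed to the deepstack levels

# Scatter the multiscale part into a per-token buffer, then split it per level.
deepstack = torch.zeros(seq_len, multiscale_dim)
deepstack[is_multimodal] = multiscale
deepstack = deepstack.view(seq_len, deepstack_num_level, visual_dim)
deepstack = deepstack.permute(1, 0, 2).contiguous()      # [num_level, seq_len, visual_dim]

print(main.shape, deepstack.shape)  # torch.Size([2, 4]) torch.Size([3, 6, 4])
```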