Commit 61274bd

[Doc] Further update multi-modal impl doc (vllm-project#33065)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent b40db4d · commit 61274bd

1 file changed: +65 −20 lines

docs/contributing/model/multimodal.md

Lines changed: 65 additions & 20 deletions
````diff
@@ -43,28 +43,73 @@ Further update the model as follows:
     )
     ```
 
-- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
+- Remove the embedding part from the [forward][torch.nn.Module.forward] method:
+    - Move the multi-modal embedding to [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
+    - The text embedding and embedding merge are handled automatically by a default implementation of [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids]. It does not need to be overridden in most cases.
+
+    ```diff
+      def forward(
+          self,
+          input_ids: torch.Tensor | None,
+    -     pixel_values: torch.Tensor,
+          positions: torch.Tensor,
+          intermediate_tensors: IntermediateTensors | None = None,
+          inputs_embeds: torch.Tensor | None = None,
+      ) -> torch.Tensor:
+    -     if inputs_embeds is None:
+    -         inputs_embeds = self.get_input_embeddings()(input_ids)
+    -
+    -     if pixel_values is not None:
+    -         image_features = self.get_image_features(
+    -             pixel_values=pixel_values,
+    -         )
+    -         special_image_mask = self.get_placeholder_mask(
+    -             input_ids,
+    -             inputs_embeds=inputs_embeds,
+    -             image_features=image_features,
+    -         )
+    -         inputs_embeds = inputs_embeds.masked_scatter(
+    -             special_image_mask,
+    -             image_features,
+    -         )
+
+          hidden_states = self.language_model(
+              input_ids,
+              positions,
+              intermediate_tensors,
+              inputs_embeds=inputs_embeds,
+          )
+          ...
+
+    + def embed_multimodal(
+    +     self,
+    +     pixel_values: torch.Tensor,
+    + ) -> MultiModalEmbeddings | None:
+    +     return self.get_image_features(
+    +         pixel_values=pixel_values,
+    +     )
+    ```
 
-??? code
+    Below we provide a boilerplate of a typical implementation pattern of [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal], but feel free to adjust it to your own needs.
 
-    ```python
-    def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
-        image_features = self.vision_encoder(image_input)
-        return self.multi_modal_projector(image_features)
-
-    def embed_multimodal(
-        self,
-        **kwargs: object,
-    ) -> MultiModalEmbeddings | None:
-        # Validate the multimodal input keyword arguments
-        image_input = self._parse_and_validate_image_input(**kwargs)
-        if image_input is None:
-            return None
-
-        # Run multimodal inputs through encoder and projector
-        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
-    ```
+    ```python
+    def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
+        image_features = self.vision_encoder(image_input)
+        return self.multi_modal_projector(image_features)
+
+    def embed_multimodal(
+        self,
+        **kwargs: object,
+    ) -> MultiModalEmbeddings | None:
+        # Validate the multimodal input keyword arguments
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+
+        # Run multimodal inputs through encoder and projector
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+    ```
 
 !!! important
     The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
````
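
The `!!! important` admonition at the end of the diff pins down the return-shape contract for `embed_multimodal`. Below is a quick illustration of the two accepted forms; the item count, feature sizes, and hidden size are made-up values for demonstration only.

```python
import torch

hidden_size = 4096  # stand-in for the language model's hidden size

# Form 1: a single 3D tensor of shape (num_items, feature_size, hidden_size).
# Only possible when every item yields the same number of feature vectors.
batched = torch.randn(2, 576, hidden_size)  # 2 images, 576 features each

# Form 2: a list (or tuple) of 2D tensors of shape (feature_size, hidden_size).
# Allows a different feature_size per item, e.g. images at different resolutions.
per_item = [
    torch.randn(576, hidden_size),   # item 0
    torch.randn(1024, hidden_size),  # item 1
]

# Either way, indexing by item must yield that item's (feature_size, hidden_size)
# embeddings, which is what `multimodal_embeddings[i]` relies on.
assert batched[1].shape == (576, hidden_size)
assert per_item[1].shape == (1024, hidden_size)
```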

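The updated text leans on the default `embed_input_ids` implementation to embed the text tokens and merge in the multimodal embeddings. The sketch below is not vLLM's actual implementation, only a conceptual illustration of such a merge; it mirrors the `masked_scatter` pattern deleted from `forward` in the diff, and the function name, signature, and single-placeholder-token assumption are ours.

```python
import torch

def merge_multimodal_embeddings(
    input_ids: torch.Tensor,    # (num_tokens,) token IDs, placeholders included
    text_embeds: torch.Tensor,  # (num_tokens, hidden_size) embedded text tokens
    mm_embeds: torch.Tensor,    # (total_feature_size, hidden_size) multimodal features
    placeholder_token_id: int,
) -> torch.Tensor:
    # Mark the positions reserved for multimodal features. The number of
    # placeholder tokens must equal total_feature_size for the scatter to line up.
    is_placeholder = (input_ids == placeholder_token_id).unsqueeze(-1)

    # Copy the multimodal feature rows into the placeholder rows, in order;
    # every other position keeps its text embedding. `mm_embeds` must match
    # `text_embeds` in dtype and device.
    return text_embeds.masked_scatter(is_placeholder, mm_embeds)
```

If `embed_multimodal` returns a list of 2D tensors, they would be concatenated along the feature dimension before a scatter like this.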
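
The boilerplate calls `self._parse_and_validate_image_input(**kwargs)` without showing it. A minimal sketch of what such a helper could look like on the model class follows; the `YourModelImageInputs` TypedDict and the expected `pixel_values` shape are illustrative assumptions rather than a fixed vLLM interface.

```python
from typing import TypedDict

import torch


class YourModelImageInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape: (num_images, num_channels, height, width)"""


def _parse_and_validate_image_input(
    self,
    **kwargs: object,
) -> YourModelImageInputs | None:
    # Defined on the model class, alongside embed_multimodal.
    pixel_values = kwargs.pop("pixel_values", None)
    if pixel_values is None:
        # Text-only request: there are no image inputs to embed.
        return None

    if not isinstance(pixel_values, torch.Tensor):
        raise ValueError("Incorrect type of pixel values. "
                         f"Got type: {type(pixel_values)}")

    return YourModelImageInputs(pixel_values=pixel_values)
```

Real models typically add shape and dtype checks here before handing the result to `_process_image_input`.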