Commit 94bc6ba

rahul-tuli and claude committed
feat: Add multimodal input support to Eagle3 Llama4
Enables multimodal input processing for Eagle3 speculative decoding with Llama4 models, supporting vision and other modalities.

Key changes:

- Updated get_input_embeddings to support multimodal embeddings
- Added merge_multimodal_embeddings integration
- Proper handling of the image_token_index configuration
- Maintains compatibility with existing text-only workflows

Co-Authored-By: Claude <[email protected]>
Signed-off-by: Rahul Tuli <[email protected]>
1 parent c1f9a25 commit 94bc6ba
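
For context on the helper this commit wires in, here is a minimal sketch of what merge_multimodal_embeddings does conceptually. It is a simplified stand-in under assumed flat (num_tokens, hidden_size) shapes, not vLLM's actual implementation (which also handles nested and batched embedding structures); the _sketch suffix marks it as hypothetical.

import torch

# Simplified, hypothetical sketch (not vLLM's implementation): scatter the
# precomputed multimodal feature rows into the text embedding sequence at
# the positions occupied by the image placeholder token.
def merge_multimodal_embeddings_sketch(
    input_ids: torch.Tensor,              # (num_tokens,) token ids
    inputs_embeds: torch.Tensor,          # (num_tokens, hidden_size)
    multimodal_embeddings: torch.Tensor,  # (num_placeholders, hidden_size)
    image_token_index: int,               # e.g. config.image_token_index
) -> torch.Tensor:
    mask = input_ids == image_token_index
    merged = inputs_embeds.clone()
    # One feature row lands on each placeholder position.
    merged[mask] = multimodal_embeddings.to(merged.dtype)
    return merged

This is also why the diff below reads the placeholder id with getattr(self.config, "image_token_index", None): a text-only configuration may not define it, so the lookup cannot assume it exists.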

File tree

1 file changed: +11 -6 lines changed

vllm/model_executor/models/llama4_eagle3.py

Lines changed: 11 additions & 6 deletions
@@ -143,6 +143,7 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        multimodal_embeddings: Optional[NestedTensors] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Forward pass for Eagle3 draft generation.
@@ -152,6 +153,7 @@ def forward(
             positions: Position indices for rotary embeddings
             hidden_states: Auxiliary hidden states from target model
             inputs_embeds: Pre-computed input embeddings (optional)
+            multimodal_embeddings: Multimodal embeddings (optional)

         Returns:
             Tuple of (hidden_states, hidden_states) following vLLM convention
@@ -160,6 +162,15 @@
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)

+        # Apply multimodal embeddings if provided
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                getattr(self.config, "image_token_index", None),
+            )
+
         # Eagle3 pattern: auxiliary hidden states have same dimension as embeddings
         # This assertion ensures compatibility for the single decoder layer
         assert hidden_states.shape[-1] == inputs_embeds.shape[-1], (
@@ -376,12 +387,6 @@ def forward(
         Returns:
             Tuple of (hidden_states, hidden_states) for vLLM compatibility
         """
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                f"{type(self).__name__} does not support multimodal inputs yet. "
-                "Multimodal support for Eagle3 is planned for future releases."
-            )
-
         return self.model(input_ids, positions, hidden_states, inputs_embeds)

     def compute_logits(
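
To illustrate the updated calling convention, a hypothetical call-site wrapper follows; model, vision_embeds, run_draft_step, and everything apart from the parameter list shown in the diff are assumptions for illustration, not code from this commit.

from typing import Optional

import torch

# Hypothetical wrapper (not from this commit) around the inner Eagle3 Llama4
# model's forward as changed above. `model` stands in for that module; all
# other names are illustrative.
def run_draft_step(
    model,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    aux_hidden_states: torch.Tensor,      # auxiliary states from the target model
    vision_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # multimodal_embeddings is the new optional keyword; text-only callers
    # simply omit it and the pre-existing behavior is unchanged.
    hidden, _ = model(
        input_ids,
        positions,
        aux_hidden_states,
        inputs_embeds=None,
        multimodal_embeddings=vision_embeds,
    )
    return hidden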
