@@ -143,6 +143,7 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
+        multimodal_embeddings: Optional[NestedTensors] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Forward pass for Eagle3 draft generation.
@@ -152,6 +153,7 @@ def forward(
             positions: Position indices for rotary embeddings
             hidden_states: Auxiliary hidden states from target model
             inputs_embeds: Pre-computed input embeddings (optional)
+            multimodal_embeddings: Multimodal embeddings (optional)
 
         Returns:
             Tuple of (hidden_states, hidden_states) following vLLM convention
@@ -160,6 +162,15 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)
 
+        # Apply multimodal embeddings if provided
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                getattr(self.config, "image_token_index", None),
+            )
+
         # Eagle3 pattern: auxiliary hidden states have same dimension as embeddings
         # This assertion ensures compatibility for the single decoder layer
         assert hidden_states.shape[-1] == inputs_embeds.shape[-1], (
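For readers unfamiliar with vLLM's `merge_multimodal_embeddings` helper, the sketch below illustrates its core semantics: scatter the multimodal embeddings into the rows of `inputs_embeds` whose token id matches the placeholder (here `image_token_index`). The function name `merge_multimodal_embeddings_sketch` and the flat-tensor input are illustrative assumptions for this sketch; the real helper also handles nested lists of tensors (`NestedTensors`).

```python
import torch


def merge_multimodal_embeddings_sketch(
    input_ids: torch.Tensor,              # (num_tokens,) prompt token ids
    inputs_embeds: torch.Tensor,          # (num_tokens, hidden_size) text embeddings
    multimodal_embeddings: torch.Tensor,  # (num_mm_tokens, hidden_size), flattened
    placeholder_token_id: int,            # e.g. config.image_token_index
) -> torch.Tensor:
    """Illustrative stand-in for vLLM's merge_multimodal_embeddings."""
    # Rows holding the image placeholder token receive the multimodal rows.
    is_placeholder = input_ids == placeholder_token_id
    # Each flattened multimodal row must line up with one placeholder position.
    assert int(is_placeholder.sum()) == multimodal_embeddings.shape[0], (
        "number of multimodal embeddings must equal number of placeholder tokens"
    )
    merged = inputs_embeds.clone()
    merged[is_placeholder] = multimodal_embeddings.to(dtype=merged.dtype)
    return merged
```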
@@ -376,12 +387,6 @@ def forward(
         Returns:
             Tuple of (hidden_states, hidden_states) for vLLM compatibility
         """
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                f"{type(self).__name__} does not support multimodal inputs yet. "
-                "Multimodal support for Eagle3 is planned for future releases."
-            )
-
         return self.model(input_ids, positions, hidden_states, inputs_embeds)
 
     def compute_logits(
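With the `NotImplementedError` guard removed, the outer forward simply passes precomputed embeddings through to the inner model. A hypothetical call site might look like the following, where `draft_model` stands in for an instance of the patched Eagle3 class and the shapes are illustrative:

```python
import torch

# Illustrative shapes only; real values come from the vLLM runner.
num_tokens, hidden_size = 8, 4096

input_ids = torch.randint(0, 32_000, (num_tokens,))
positions = torch.arange(num_tokens)
aux_hidden_states = torch.randn(num_tokens, hidden_size)

# Text-only path: inputs_embeds stays None, so the inner model
# embeds input_ids itself.
hidden, _ = draft_model(input_ids, positions, aux_hidden_states)

# Multimodal path: previously this raised NotImplementedError; now the
# precomputed (already merged) embeddings flow straight through.
inputs_embeds = torch.randn(num_tokens, hidden_size)
hidden, _ = draft_model(
    input_ids, positions, aux_hidden_states, inputs_embeds=inputs_embeds
)
```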