Commit 730f04d

Override get_input_embeddings in Eagle3 to process text-only inputs
Implement custom get_input_embeddings() in Eagle3LlamaForCausalLM that accepts multimodal parameters but only processes text embeddings. This ensures the Llama3-based Eagle3 drafter correctly handles text inputs while remaining compatible with multimodal verifier interfaces. The drafter receives multimodal context through auxiliary hidden states from the verifier rather than processing multimodal inputs directly. Signed-off-by: rahul-tuli <[email protected]> Signed-off-by: Rahul Tuli <[email protected]>
1 parent 58dfcf6 commit 730f04d

File tree

1 file changed: +9 −2 lines changed


vllm/model_executor/models/llama_eagle3.py

Lines changed: 9 additions & 2 deletions
@@ -21,6 +21,7 @@
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal.inputs import NestedTensors

 from .utils import AutoWeightsLoader, maybe_prefix

@@ -241,8 +242,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             requires_grad=False,
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+        is_multimodal: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # The llama3 drafter only processes text embeddings
+        return self.model.embed_tokens(input_ids)

     def forward(
         self,
0 commit comments