Commit 25def1f

Override get_input_embeddings in Eagle3 to process text-only inputs
Implement a custom get_input_embeddings() in Eagle3LlamaForCausalLM that accepts multimodal parameters but processes only text embeddings. This lets the Llama3-based Eagle3 drafter handle text inputs correctly while remaining compatible with multimodal verifier interfaces: the drafter receives multimodal context through auxiliary hidden states from the verifier rather than processing multimodal inputs directly.

Signed-off-by: rahul-tuli <[email protected]>
1 parent fcaf21e commit 25def1f

File tree

1 file changed: +9 −2 lines


vllm/model_executor/models/llama_eagle3.py

Lines changed: 9 additions & 2 deletions
@@ -20,6 +20,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                               LlamaForCausalLM)
+from vllm.multimodal.inputs import NestedTensors

 from .utils import AutoWeightsLoader, maybe_prefix


@@ -242,8 +243,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             requires_grad=False,
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+        is_multimodal: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # The llama3 drafter only processes text embeddings
+        return self.model.embed_tokens(input_ids)

     def forward(
         self,
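The pattern in this commit can be sketched in isolation: the drafter's get_input_embeddings() accepts the multimodal keyword arguments so its signature matches the verifier-facing interface, but it looks up only the text token embeddings and ignores the rest. The class and embedding table below are hypothetical stand-ins for illustration, not the real vLLM classes.

```python
from typing import Any, Optional


class TextOnlyDrafter:
    """Hypothetical stand-in for an Eagle3-style text-only drafter."""

    def __init__(self, embed_tokens: dict):
        # embed_tokens maps token id -> embedding vector
        # (a plain-dict stand-in for nn.Embedding).
        self.embed_tokens = embed_tokens

    def get_input_embeddings(
        self,
        input_ids: list,
        multimodal_embeddings: Optional[Any] = None,  # accepted, unused
        is_multimodal: Optional[Any] = None,          # accepted, unused
    ) -> list:
        # Text-only lookup: multimodal context reaches the drafter via the
        # verifier's auxiliary hidden states, not through these arguments.
        return [self.embed_tokens[i] for i in input_ids]


drafter = TextOnlyDrafter({0: [0.0, 0.0], 1: [1.0, 1.0]})
# Multimodal arguments are accepted for interface compatibility but ignored.
embs = drafter.get_input_embeddings([1, 0], multimodal_embeddings=["ignored"])
```

Keeping the extra parameters in the signature (rather than dropping them) is what allows a multimodal verifier to call the drafter uniformly without special-casing text-only models.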

0 commit comments