
Commit 58dfcf6

Implement SupportsEagle3 interface for Llama4 multimodal models
Add Eagle3 support to Llama4ForConditionalGeneration by implementing the set_aux_hidden_state_layers() and get_eagle3_aux_hidden_state_layers() methods. Both methods delegate to the underlying Llama4ForCausalLM language model, enabling Eagle3 speculative decoding with Llama4 multimodal verifier models. This allows text-only Eagle3 drafters to work with Llama4 multimodal verifiers by consuming auxiliary hidden states from the specified layers.

Signed-off-by: rahul-tuli <[email protected]>
Signed-off-by: Rahul Tuli <[email protected]>
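To make the delegation concrete, below is a minimal sketch of the text-only language model on the receiving end of these calls. Only the two method names come from this commit; the class name, the attribute name aux_hidden_state_layers, and the default layer selection are illustrative assumptions.

# Illustrative sketch (not from this commit): a text-only causal LM exposing
# the EAGLE3 auxiliary-hidden-state hooks that the multimodal wrapper
# delegates to. The attribute and default layer choice are assumptions.
class TextOnlyCausalLMSketch:
    def __init__(self, num_layers: int) -> None:
        self.num_layers = num_layers
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        # Record which decoder layers should also emit their hidden states
        # so an EAGLE3 drafter can consume them during verification.
        self.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        # Return the configured layers, falling back to an early/middle/late
        # spread if nothing was set explicitly.
        if self.aux_hidden_state_layers:
            return self.aux_hidden_state_layers
        return (2, self.num_layers // 2, self.num_layers - 3)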
1 parent 07e7c78 commit 58dfcf6

File tree: 1 file changed, +23 −2 lines

vllm/model_executor/models/mllama4.py

Lines changed: 23 additions & 2 deletions
@@ -64,7 +64,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsEagle3,
+                         SupportsMultiModal, SupportsPP)
 from .llama4 import Llama4ForCausalLM
 from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
 from .vision import run_dp_sharded_vision_model
@@ -717,7 +718,9 @@ def get_dummy_mm_data(
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class Llama4ForConditionalGeneration(
+        nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -767,6 +770,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        """Set which layers should output auxiliary hidden states for EAGLE3."""
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, 'set_aux_hidden_state_layers')
+        self.language_model.set_aux_hidden_state_layers(layers)
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Get the layer indices for auxiliary hidden state outputs.
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(
+            self.language_model, "get_eagle3_aux_hidden_state_layers"
+        )
+        return self.language_model.get_eagle3_aux_hidden_state_layers()
+
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Optional[Llama4ImagePatchInputs]:
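For context, pairing a text-only EAGLE3 drafter with a Llama4 multimodal verifier would look roughly like the sketch below. This is a usage sketch only: the model identifier, drafter path, and the exact speculative_config keys are assumptions and may differ across vLLM versions.

# Hedged usage sketch (not from this commit): EAGLE3 speculative decoding
# with a Llama4 multimodal verifier. Model names and config keys below are
# assumptions for illustration.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # assumed verifier checkpoint
    speculative_config={
        "method": "eagle3",
        "model": "path/to/eagle3-draft-model",  # hypothetical drafter path
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(
    ["Summarize EAGLE3 speculative decoding in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)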
