Skip to content

Commit 936ad29

Browse files
committed
Implement SupportsEagle3 interface for Llama4 multimodal models
Add Eagle3 support to Llama4ForConditionalGeneration by implementing set_aux_hidden_state_layers() and get_eagle3_aux_hidden_state_layers() methods. Both methods delegate to the underlying Llama4ForCausalLM language model, enabling Eagle3 speculative decoding with Llama4 multimodal verifier models. This allows text-only Eagle3 drafters to work with Llama4 multimodal verifiers by consuming auxiliary hidden states from specified layers.
1 parent b7dce49 commit 936ad29

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

vllm/model_executor/models/mllama4.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@
5454
from vllm.sequence import IntermediateTensors
5555
from vllm.utils.tensor_schema import TensorSchema, TensorShape
5656

57-
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
57+
from .interfaces import (MultiModalEmbeddings, SupportsEagle3,
58+
SupportsMultiModal, SupportsPP)
5859
from .llama4 import Llama4ForCausalLM
5960
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
6061
from .vision import run_dp_sharded_vision_model
@@ -708,8 +709,8 @@ def get_dummy_mm_data(
708709
info=Mllama4ProcessingInfo,
709710
dummy_inputs=Mllama4DummyInputsBuilder,
710711
)
711-
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
712-
SupportsPP):
712+
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
713+
SupportsEagle3):
713714
packed_modules_mapping = {
714715
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
715716
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -758,6 +759,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
758759
self.make_empty_intermediate_tensors = (
759760
self.language_model.make_empty_intermediate_tensors)
760761

762+
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
763+
"""Set which layers should output auxiliary hidden states for EAGLE3."""
764+
# Delegate to underlying language model (Llama4ForCausalLM)
765+
assert hasattr(self.language_model, 'set_aux_hidden_state_layers')
766+
self.language_model.set_aux_hidden_state_layers(layers)
767+
768+
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
769+
"""Get the layer indices for auxiliary hidden state outputs.
770+
771+
Note: The GPU model runner will override this with layers from
772+
the speculative config if available, providing dynamic configuration.
773+
"""
774+
# Delegate to underlying language model (Llama4ForCausalLM)
775+
assert hasattr(self.language_model,
776+
'get_eagle3_aux_hidden_state_layers')
777+
self.language_model.get_eagle3_aux_hidden_state_layers()
778+
761779
def _parse_and_validate_image_input(
762780
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
763781
# num_images, 1, num_chunks, channel, image_size, image_size

0 commit comments

Comments
 (0)