|
54 | 54 | from vllm.sequence import IntermediateTensors
|
55 | 55 | from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
56 | 56 |
|
57 |
| -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP |
| 57 | +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, |
| 58 | + SupportsMultiModal, SupportsPP) |
58 | 59 | from .llama4 import Llama4ForCausalLM
|
59 | 60 | from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
60 | 61 | from .vision import run_dp_sharded_vision_model
|
@@ -708,8 +709,8 @@ def get_dummy_mm_data(
|
708 | 709 | info=Mllama4ProcessingInfo,
|
709 | 710 | dummy_inputs=Mllama4DummyInputsBuilder,
|
710 | 711 | )
|
711 |
| -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, |
712 |
| - SupportsPP): |
| 712 | +class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, |
| 713 | + SupportsEagle3): |
713 | 714 | packed_modules_mapping = {
|
714 | 715 | "qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
715 | 716 | "gate_up_proj": ["gate_proj", "up_proj"],
|
@@ -758,6 +759,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
758 | 759 | self.make_empty_intermediate_tensors = (
|
759 | 760 | self.language_model.make_empty_intermediate_tensors)
|
760 | 761 |
|
| 762 | + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: |
| 763 | + """Set which layers should output auxiliary hidden states for EAGLE3.""" |
| 764 | + # Delegate to underlying language model (Llama4ForCausalLM) |
| 765 | + assert hasattr(self.language_model, 'set_aux_hidden_state_layers') |
| 766 | + self.language_model.set_aux_hidden_state_layers(layers) |
| 767 | + |
| 768 | + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: |
| 769 | + """Get the layer indices for auxiliary hidden state outputs. |
| 770 | +
|
| 771 | + Note: The GPU model runner will override this with layers from |
| 772 | + the speculative config if available, providing dynamic configuration. |
| 773 | + """ |
| 774 | + # Delegate to underlying language model (Llama4ForCausalLM) |
| 775 | + assert hasattr(self.language_model, |
| 776 | + 'get_eagle3_aux_hidden_state_layers') |
| 777 | + self.language_model.get_eagle3_aux_hidden_state_layers() |
| 778 | + |
761 | 779 | def _parse_and_validate_image_input(
|
762 | 780 | self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
|
763 | 781 | # num_images, 1, num_chunks, channel, image_size, image_size
|
|
0 commit comments