Skip to content

Commit c1f9a25

Browse files
rahul-tuliclaude
andcommitted
feat: Add Eagle3 speculative decoding support for Llama4
Implements Eagle3 speculative decoding architecture for Llama4 models, enabling faster inference through single-layer draft model speculation. Key additions: - Eagle3Llama4ForCausalLM: Main implementation with single-layer draft architecture - SupportsEagle3 interface integration for Llama4ForCausalLM - Model registry mappings for Eagle3 Llama4 models - Auxiliary hidden state combination and vocabulary mapping - Draft-to-target token conversion for multi-vocabulary support 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> Signed-off-by: Rahul Tuli <[email protected]>
1 parent 0e3bb54 commit c1f9a25

File tree

3 files changed

+566
-1
lines changed

3 files changed

+566
-1
lines changed

vllm/model_executor/models/llama4.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from vllm.model_executor.model_loader.weight_utils import (
4040
default_weight_loader, maybe_remap_kv_scale_name)
4141

42+
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
4243
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
4344
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
4445
is_pp_missing_parameter)
@@ -636,7 +637,8 @@ def load_weights(self, weights: Iterable[tuple[str,
636637
return loaded_params
637638

638639

639-
class Llama4ForCausalLM(LlamaForCausalLM):
640+
class Llama4ForCausalLM(LlamaForCausalLM, SupportsLoRA, SupportsPP,
641+
SupportsEagle3):
640642

641643
packed_modules_mapping = {
642644
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -679,6 +681,30 @@ def load_weights(self, weights: Iterable[tuple[str,
679681
]
680682
return loader.load_weights(weights)
681683

684+
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
685+
"""
686+
Set auxiliary hidden state layers for Eagle3 speculation.
687+
688+
Args:
689+
layers: Tuple of layer indices that should output auxiliary
690+
hidden states for Eagle3 speculation.
691+
"""
692+
self.model.aux_hidden_state_layers = layers
693+
694+
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
695+
"""
696+
Get the layer indices for Eagle3 auxiliary hidden states.
697+
698+
Returns:
699+
Tuple of layer indices for auxiliary hidden state outputs.
700+
Typically includes early, middle, and late layers for optimal
701+
speculation performance.
702+
"""
703+
num_layers = len(self.model.layers)
704+
# Standard Eagle3 strategy: early, middle, and late layers
705+
# Ensures good representation across model depth
706+
return (2, num_layers // 2, num_layers - 3)
707+
682708
def permute_qk_weight_for_rotary(
683709
self,
684710
name: str,

0 commit comments

Comments
 (0)