|
39 | 39 | from vllm.model_executor.model_loader.weight_utils import (
|
40 | 40 | default_weight_loader, maybe_remap_kv_scale_name)
|
41 | 41 |
|
| 42 | +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP |
42 | 43 | from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
43 | 44 | from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
|
44 | 45 | is_pp_missing_parameter)
|
@@ -636,7 +637,8 @@ def load_weights(self, weights: Iterable[tuple[str,
|
636 | 637 | return loaded_params
|
637 | 638 |
|
638 | 639 |
|
639 |
| -class Llama4ForCausalLM(LlamaForCausalLM): |
| 640 | +class Llama4ForCausalLM(LlamaForCausalLM, SupportsLoRA, SupportsPP, |
| 641 | + SupportsEagle3): |
640 | 642 |
|
641 | 643 | packed_modules_mapping = {
|
642 | 644 | "qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
@@ -679,6 +681,30 @@ def load_weights(self, weights: Iterable[tuple[str,
|
679 | 681 | ]
|
680 | 682 | return loader.load_weights(weights)
|
681 | 683 |
|
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """
    Set auxiliary hidden state layers for Eagle3 speculation.

    The tuple is stored unchanged on ``self.model.aux_hidden_state_layers``;
    this method performs no validation, copying, or range checking.

    Args:
        layers: Tuple of layer indices that should output auxiliary
            hidden states for Eagle3 speculation.
    """
    # Plain attribute assignment on the wrapped model object; how the
    # attribute is consumed is not visible from this method.
    self.model.aux_hidden_state_layers = layers
| 693 | + |
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """
    Get the layer indices for Eagle3 auxiliary hidden states.

    The indices are derived purely from ``len(self.model.layers)``:
    a fixed early layer (2), the midpoint (``num_layers // 2``) and a
    late layer (``num_layers - 3``).

    Returns:
        Tuple of three layer indices for auxiliary hidden state outputs.

    Note:
        The three values are distinct and strictly ascending only when
        the model has at least 7 layers; with fewer than 3 layers the
        last index is negative. No clamping is applied here.
    """
    num_layers = len(self.model.layers)
    # Standard Eagle3 strategy: early, middle, and late layers,
    # so the selection spans the depth of the model.
    return (2, num_layers // 2, num_layers - 3)
| 707 | + |
682 | 708 | def permute_qk_weight_for_rotary(
|
683 | 709 | self,
|
684 | 710 | name: str,
|
|
0 commit comments