Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import is_interleaved

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
Expand Down Expand Up @@ -439,7 +439,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -485,6 +485,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
"""Set auxiliary hidden state layers for Eagle3 speculation."""
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
"""Get the layer indices for Eagle3 auxiliary hidden states."""
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

def forward(
self,
input_ids: torch.Tensor,
Expand Down