87 changes: 57 additions & 30 deletions modelopt/torch/speculative/plugins/transformers.py
@@ -38,12 +38,7 @@
from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
-from transformers.models.llama.modeling_llama import (
-LlamaAttention,
-LlamaDecoderLayer,
-LlamaRMSNorm,
-LlamaRotaryEmbedding,
-)
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import ModelOutput

@@ -53,6 +48,9 @@
from ..medusa.conversion import MedusaDMRegistry
from ..medusa.medusa_model import MedusaModel
from ..utils import AcceptanceRateValidation, ResBlock, temporary_set_config_value
+from .modeling_deepseek import (
+DeepseekV3DecoderLayer,  # This is the definition from K2's modeling_deepseek.py
+)

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

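IGNORE_TOKEN_ID resolves to LabelSmoother.ignore_index, i.e. -100, which is also the default ignore_index of torch's CrossEntropyLoss. A minimal illustration of how such a sentinel masks positions out of the loss (shapes and the masking choice here are hypothetical):

```python
import torch
from torch.nn import CrossEntropyLoss

IGNORE_TOKEN_ID = -100  # == LabelSmoother.ignore_index

logits = torch.randn(2, 8, 32000)       # (batch, seq, vocab), made-up sizes
labels = torch.randint(0, 32000, (2, 8))
labels[:, :4] = IGNORE_TOKEN_ID         # e.g. exclude prompt tokens from the loss

# Positions holding -100 contribute nothing to the loss.
loss = CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
```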
@@ -219,30 +217,38 @@ def __init__(self, config, decoder_layer_cls, bias=False):
)

first_layer_attn = self.layers[0].self_attn
-if not isinstance(first_layer_attn, LlamaAttention):
-raise ValueError("EAGLE-3 only support LlamaAttention.")

# EAGLE-3's first attention require [input_layernorm_output, aux_hidden_states]
first_layer_attn.register_forward_pre_hook(
self._eagle3_attention_forward_pre_hook, with_kwargs=True
)

# Modify qkv projection in first layer to accept 2h hidden size.
-first_layer_attn.q_proj = nn.Linear(
-first_layer_attn.q_proj.in_features * 2,
-first_layer_attn.q_proj.out_features,
+first_layer_attn.q_a_proj = nn.Linear(
+first_layer_attn.q_a_proj.in_features * 2,
+first_layer_attn.q_a_proj.out_features,
bias=first_layer_attn.config.attention_bias,
)
-first_layer_attn.k_proj = nn.Linear(
-first_layer_attn.k_proj.in_features * 2,
-first_layer_attn.k_proj.out_features,
-bias=first_layer_attn.config.attention_bias,
-)
-first_layer_attn.v_proj = nn.Linear(
-first_layer_attn.v_proj.in_features * 2,
-first_layer_attn.v_proj.out_features,
+first_layer_attn.kv_a_proj_with_mqa = nn.Linear(
+first_layer_attn.kv_a_proj_with_mqa.in_features * 2,
+first_layer_attn.kv_a_proj_with_mqa.out_features,
bias=first_layer_attn.config.attention_bias,
)
+# first_layer_attn.q_proj = nn.Linear(
+#     first_layer_attn.q_proj.in_features * 2,
+#     first_layer_attn.q_proj.out_features,
+#     bias=first_layer_attn.config.attention_bias,
+# )
+# first_layer_attn.k_proj = nn.Linear(
+#     first_layer_attn.k_proj.in_features * 2,
+#     first_layer_attn.k_proj.out_features,
+#     bias=first_layer_attn.config.attention_bias,
+# )
+# first_layer_attn.v_proj = nn.Linear(
+#     first_layer_attn.v_proj.in_features * 2,
+#     first_layer_attn.v_proj.out_features,
+#     bias=first_layer_attn.config.attention_bias,
+# )

# In EAGLE-3, input_embeds and hidden_states are normalized separately before concatenation.
self.input_embeds_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
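The doubled in_features above exist because EAGLE-3 feeds the first attention layer a concatenation of two hidden-size-wide streams, each RMS-normalized separately, so the MLA input projections must accept 2h. A minimal sketch of that data flow with hypothetical dimensions (nn.RMSNorm stands in for LlamaRMSNorm; the q_lora_rank value is made up):

```python
import torch
import torch.nn as nn

hidden_size, q_lora_rank = 1024, 512  # hypothetical MLA dimensions

# Two h-wide streams are normalized separately, then concatenated to 2h.
embeds_norm = nn.RMSNorm(hidden_size)
hidden_norm = nn.RMSNorm(hidden_size)
q_a_proj = nn.Linear(2 * hidden_size, q_lora_rank, bias=False)  # widened: 2h -> rank

inputs_embeds = torch.randn(1, 8, hidden_size)
hidden_states = torch.randn(1, 8, hidden_size)
concat = torch.cat((embeds_norm(inputs_embeds), hidden_norm(hidden_states)), dim=-1)
q_latent = q_a_proj(concat)  # works only because in_features was doubled
```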
Expand Down Expand Up @@ -309,7 +315,7 @@ def forward(
hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))

for decoder_layer in self.layers:
-layer_outputs = decoder_layer(
+hidden_states, past_key_values = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -318,11 +324,6 @@
use_cache=use_cache,
position_embeddings=position_embeddings,
)
-# For HF>= 4.54.0, the layer_outputs is a tensor, for older, it is a tuple.
-if isinstance(layer_outputs, tuple):
-hidden_states = layer_outputs[0]
-else:
-hidden_states = layer_outputs

pre_norm_h = hidden_states

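The direct tuple unpacking above replaces the removed version check and assumes the decoder layer always returns a (hidden_states, past_key_values) pair. A toy layer showing that assumed contract (a hypothetical stand-in; the real signature is defined by DeepseekV3DecoderLayer in K2's modeling_deepseek.py):

```python
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    """Hypothetical stand-in illustrating the assumed return contract."""

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=False,
        position_embeddings=None,
    ):
        # ... attention and MLP omitted ...
        return hidden_states, past_key_values  # always a 2-tuple, never a bare tensor
```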
@@ -445,6 +446,7 @@ def modify(
Args:
config: The config for eagle decoder layers.
"""
+eagle_reuse_base_decoder = True
super().modify(
eagle_offline=eagle_offline,
eagle_hidden_state_distillation=eagle_hidden_state_distillation,
@@ -458,9 +460,31 @@
self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
if self.eagle_config._attn_implementation is None:
self.eagle_config._attn_implementation = "sdpa"
-decoder_cls = (
-type(self.model.layers[-1]) if self.eagle_reuse_base_decoder else LlamaDecoderLayer
-)
+
+# Temporary for MLA EAGLE: reuse the base model's config to initialize the eagle decoder layers.
+tmp_eagle_config = copy.deepcopy(self.config)
+# Set the fields the eagle module needs for MLA.
+tmp_eagle_config.num_hidden_layers = self.eagle_config.num_hidden_layers
+tmp_eagle_config.use_last_layernorm = self.eagle_config.use_last_layernorm
+tmp_eagle_config.use_input_layernorm_in_first_layer = (
+self.eagle_config.use_input_layernorm_in_first_layer
+)
+tmp_eagle_config.eagle_aux_hidden_state_layer_ids = (
+self.eagle_config.eagle_aux_hidden_state_layer_ids
+)
+tmp_eagle_config.use_aux_hidden_state = self.eagle_config.use_aux_hidden_state
+tmp_eagle_config.use_mtp_layernorm = self.eagle_config.use_mtp_layernorm
+tmp_eagle_config.draft_vocab_size = self.eagle_config.draft_vocab_size
+tmp_eagle_config.has_lm_head = self.eagle_config.has_lm_head
+
+# Hard-coded; equivalent to (2, num_layers // 2, num_layers - 3) for a 61-layer base model.
+tmp_eagle_config.eagle_aux_hidden_state_layer_ids = [2, 30, 58]
+tmp_eagle_config._attn_implementation = "eager"
+
+# NOTE: by default, K2 uses an MLP instead of MoE for the first layer, so we leave it as is.
+
+# Hard-coded to use K2's decoder layer type.
+decoder_cls = DeepseekV3DecoderLayer

# Use default aux_hidden_state layers if use_aux_hidden_state is True
# but no layer id is given
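The tmp_eagle_config block above clones the full base-model config, so MLA-specific fields (latent ranks, RoPE settings, and so on) carry over to the draft layers, and then overrides only the EAGLE-specific fields. A condensed sketch of the pattern (field list abridged; the helper name is hypothetical):

```python
import copy

def make_draft_config(base_config, eagle_config):
    """Clone the base config, then copy over draft-specific fields (sketch)."""
    cfg = copy.deepcopy(base_config)  # keeps MLA fields intact
    for name in ("num_hidden_layers", "draft_vocab_size", "use_last_layernorm"):
        setattr(cfg, name, getattr(eagle_config, name))
    cfg._attn_implementation = "eager"  # matches the hard-coded choice above
    return cfg
```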
@@ -478,10 +502,12 @@
)

self.eagle_module = EagleModule(
-self.eagle_config,
+tmp_eagle_config,
decoder_cls,
)
-self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config)
+# self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config)
+# DeepseekV3 only applies positional embeddings inside attention.
+self.eagle_rotary_emb = nn.Identity()

# find base model, lm head, and embeddings paths
self._find_base_model_parts()
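The nn.Identity swap above is a pure placeholder: it keeps the eagle_rotary_emb attribute (and the module tree) intact while doing nothing, since DeepseekV3 computes rotary embeddings inside its attention blocks. Note that nn.Identity only accepts a single argument, which is why the forward pass below sets position_embeddings = None instead of calling it with the old (hidden_states, position_ids) signature:

```python
import torch
import torch.nn as nn

rotary = nn.Identity()
x = torch.randn(2, 4)
assert rotary(x) is x  # single-argument call is a pass-through
# rotary(x, position_ids) would raise a TypeError: forward() takes one input
```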
@@ -826,7 +852,8 @@ def forward(
with torch.no_grad():
inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)

-position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
+# position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
+position_embeddings = None
past_key_values.eagle_cache = eagle_cache

# ====Perform training-time-testing with 3 extra eagle forward passes====
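With position_embeddings forced to None, positional information must be recovered inside the MLA attention from position_ids. A rough sketch of a rotary embedding computed in-attention (rotate-half form; the actual MLA RoPE math lives in K2's modeling_deepseek.py, so treat this as illustrative only):

```python
import torch

def apply_rope_in_attention(q, position_ids, inv_freq):
    """Sketch: apply rotary embedding to q using position_ids directly.

    q: (batch, heads, seq, head_dim); position_ids: (batch, seq);
    inv_freq: (head_dim // 2,) inverse frequencies.
    """
    angles = position_ids[:, None, :, None].float() * inv_freq  # (b, 1, s, d/2)
    cos, sin = angles.cos(), angles.sin()
    q1, q2 = q.chunk(2, dim=-1)
    return torch.cat((q1 * cos - q2 * sin, q2 * cos + q1 * sin), dim=-1)
```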