diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 1aed13e87..ce033dbd0 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -38,12 +38,7 @@
 from torch.nn import CrossEntropyLoss
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaRMSNorm,
-    LlamaRotaryEmbedding,
-)
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 from transformers.trainer_pt_utils import LabelSmoother
 from transformers.utils import ModelOutput
 
@@ -53,6 +48,9 @@
 from ..medusa.conversion import MedusaDMRegistry
 from ..medusa.medusa_model import MedusaModel
 from ..utils import AcceptanceRateValidation, ResBlock, temporary_set_config_value
+from .modeling_deepseek import (
+    DeepseekV3DecoderLayer,  # This is the def from K2's modeling_deepseek.py
+)
 
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
@@ -219,8 +217,6 @@ def __init__(self, config, decoder_layer_cls, bias=False):
         )
 
         first_layer_attn = self.layers[0].self_attn
-        if not isinstance(first_layer_attn, LlamaAttention):
-            raise ValueError("EAGLE-3 only support LlamaAttention.")
 
         # EAGLE-3's first attention require [input_layernorm_output, aux_hidden_states]
         first_layer_attn.register_forward_pre_hook(
@@ -228,21 +224,31 @@
         )
 
         # Modify qkv projection in first layer to accept 2h hidden size.
-        first_layer_attn.q_proj = nn.Linear(
-            first_layer_attn.q_proj.in_features * 2,
-            first_layer_attn.q_proj.out_features,
+        first_layer_attn.q_a_proj = nn.Linear(
+            first_layer_attn.q_a_proj.in_features * 2,
+            first_layer_attn.q_a_proj.out_features,
             bias=first_layer_attn.config.attention_bias,
         )
-        first_layer_attn.k_proj = nn.Linear(
-            first_layer_attn.k_proj.in_features * 2,
-            first_layer_attn.k_proj.out_features,
-            bias=first_layer_attn.config.attention_bias,
-        )
-        first_layer_attn.v_proj = nn.Linear(
-            first_layer_attn.v_proj.in_features * 2,
-            first_layer_attn.v_proj.out_features,
+        first_layer_attn.kv_a_proj_with_mqa = nn.Linear(
+            first_layer_attn.kv_a_proj_with_mqa.in_features * 2,
+            first_layer_attn.kv_a_proj_with_mqa.out_features,
             bias=first_layer_attn.config.attention_bias,
         )
+        # first_layer_attn.q_proj = nn.Linear(
+        #     first_layer_attn.q_proj.in_features * 2,
+        #     first_layer_attn.q_proj.out_features,
+        #     bias=first_layer_attn.config.attention_bias,
+        # )
+        # first_layer_attn.k_proj = nn.Linear(
+        #     first_layer_attn.k_proj.in_features * 2,
+        #     first_layer_attn.k_proj.out_features,
+        #     bias=first_layer_attn.config.attention_bias,
+        # )
+        # first_layer_attn.v_proj = nn.Linear(
+        #     first_layer_attn.v_proj.in_features * 2,
+        #     first_layer_attn.v_proj.out_features,
+        #     bias=first_layer_attn.config.attention_bias,
+        # )
 
         # In EAGLE-3, input_embeds and hidden_states are normalized separately before concatenation.
         self.input_embeds_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -309,7 +315,7 @@ def forward(
         hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
 
         for decoder_layer in self.layers:
-            layer_outputs = decoder_layer(
+            hidden_states, past_key_values = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
@@ -318,11 +324,6 @@
                 use_cache=use_cache,
                 position_embeddings=position_embeddings,
             )
-            # For HF>= 4.54.0, the layer_outputs is a tensor, for older, it is a tuple.
-            if isinstance(layer_outputs, tuple):
-                hidden_states = layer_outputs[0]
-            else:
-                hidden_states = layer_outputs
 
         pre_norm_h = hidden_states
 
@@ -445,6 +446,7 @@ def modify(
         Args:
             config: The config for eagle decoder layers.
         """
+        eagle_reuse_base_decoder = True
         super().modify(
             eagle_offline=eagle_offline,
             eagle_hidden_state_distillation=eagle_hidden_state_distillation,
@@ -458,9 +460,31 @@ def modify(
         self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
         if self.eagle_config._attn_implementation is None:
             self.eagle_config._attn_implementation = "sdpa"
-        decoder_cls = (
-            type(self.model.layers[-1]) if self.eagle_reuse_base_decoder else LlamaDecoderLayer
+
+        # tmp for MLA eagle: reuse the base model's config to initialize the eagle decoder layer
+        tmp_eagle_config = copy.deepcopy(self.config)
+        # set some necessary fields for MLA eagle
+        tmp_eagle_config.num_hidden_layers = self.eagle_config.num_hidden_layers
+        tmp_eagle_config.use_last_layernorm = self.eagle_config.use_last_layernorm
+        tmp_eagle_config.use_input_layernorm_in_first_layer = (
+            self.eagle_config.use_input_layernorm_in_first_layer
+        )
+        tmp_eagle_config.eagle_aux_hidden_state_layer_ids = (
+            self.eagle_config.eagle_aux_hidden_state_layer_ids
         )
+        tmp_eagle_config.use_aux_hidden_state = self.eagle_config.use_aux_hidden_state
+        tmp_eagle_config.use_mtp_layernorm = self.eagle_config.use_mtp_layernorm
+        tmp_eagle_config.draft_vocab_size = self.eagle_config.draft_vocab_size
+        tmp_eagle_config.has_lm_head = self.eagle_config.has_lm_head
+
+        # hard-coded; this is the same as (2, num_layers//2, num_layers-3)
+        tmp_eagle_config.eagle_aux_hidden_state_layer_ids = [2, 30, 58]
+        tmp_eagle_config._attn_implementation = "eager"
+
+        # NOTE: by default, K2 uses an MLP instead of MoE for the first layer, so we leave it as is.
+
+        # hard-coded to use K2's decoder type
+        decoder_cls = DeepseekV3DecoderLayer
 
         # Use default aux_hidden_state layers if use_aux_hidden_state is True
         # but no layer id is given
@@ -478,10 +502,12 @@ def modify(
             )
 
         self.eagle_module = EagleModule(
-            self.eagle_config,
+            tmp_eagle_config,
             decoder_cls,
         )
-        self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config)
+        # self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config)
+        # DeepseekV3 only applies positional embeddings inside attention
+        self.eagle_rotary_emb = nn.Identity()
 
         # find base model, lm head, and embeddings paths
         self._find_base_model_parts()
@@ -826,7 +852,8 @@ def forward(
         with torch.no_grad():
            inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)
 
-        position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
+        # position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
+        position_embeddings = None
         past_key_values.eagle_cache = eagle_cache
 
         # ====Perform training-time-testing with 3 extra eagle forward passes====
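For reference, a minimal, self-contained sketch of the first-layer projection widening in `EagleModule.__init__` above: the MLA down-projections (`q_a_proj`, `kv_a_proj_with_mqa`) are rebuilt so the first attention can take the 2h concatenation of [input_layernorm_output, aux_hidden_states]. The `_DummyMLAAttention` class, its sizes, and `widen_first_layer_projections` are illustrative stand-ins, not part of the patch.

```python
import torch
from torch import nn


class _DummyMLAAttention(nn.Module):
    """Illustrative stand-in for an MLA attention block with low-rank q/kv input projections."""

    def __init__(self, hidden_size: int = 64, q_lora_rank: int = 32, kv_lora_rank: int = 16):
        super().__init__()
        self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
        self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank, bias=False)


def widen_first_layer_projections(attn: nn.Module, bias: bool = False) -> None:
    """Rebuild the q/kv down-projections so they accept a 2x hidden-size input.

    The first EAGLE-3 attention receives [input_layernorm_output, aux_hidden_states]
    concatenated along the hidden dimension, so the input width of its projections
    must be doubled while the output width stays the same.
    """
    attn.q_a_proj = nn.Linear(
        attn.q_a_proj.in_features * 2, attn.q_a_proj.out_features, bias=bias
    )
    attn.kv_a_proj_with_mqa = nn.Linear(
        attn.kv_a_proj_with_mqa.in_features * 2, attn.kv_a_proj_with_mqa.out_features, bias=bias
    )


if __name__ == "__main__":
    attn = _DummyMLAAttention(hidden_size=64)
    widen_first_layer_projections(attn)
    x = torch.randn(1, 8, 128)  # [batch, seq, 2 * hidden_size]
    assert attn.q_a_proj(x).shape == (1, 8, 32)
    assert attn.kv_a_proj_with_mqa(x).shape == (1, 8, 16)
```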