@@ -185,7 +185,8 @@ def __init__(self, config, decoder_layer_cls, bias=False):
         self.config = config
 
         # Use flex attention for efficient TTT
-        config._attn_implementation = "flex_attention"
+        # config._attn_implementation = "flex_attention"
+        config.attn_implementation = "sdpa"
 
         self.layers = nn.ModuleList(
             [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -373,6 +374,19 @@ def pop_aux_hidden_states(self):
 
         return aux_h_list
 
+    def _get_base_model_parts(self):
+        """Helper function to extract model parts from different model types."""
+        base_model = getattr(self, "model", getattr(self, "backbone", None))
+        base_model_embeddings = getattr(
+            base_model, "embed_tokens", getattr(base_model, "embeddings", None)
+        )
+        base_model_lm_head = getattr(self, "lm_head", None)
+        # Check that all parts were found
+        for part in [base_model, base_model_embeddings, base_model_lm_head]:
+            if not isinstance(part, torch.nn.Module):
+                raise ValueError(f"Part {part} is not a torch.nn.Module")
+        return base_model, base_model_embeddings, base_model_lm_head
+
     def modify(
         self,
         eagle_offline,
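For context, the `getattr` fallback chain in `_get_base_model_parts` is meant to cover both Llama-style models (`self.model.embed_tokens`) and backbone-style models (`self.backbone.embeddings`). A minimal sketch of how the same chain resolves on two hypothetical module layouts (the class names below are made up for illustration):

```python
import torch.nn as nn

class LlamaStyle(nn.Module):
    # Hypothetical layout: .model.embed_tokens + .lm_head
    def __init__(self):
        super().__init__()
        self.model = nn.Module()
        self.model.embed_tokens = nn.Embedding(32, 8)
        self.lm_head = nn.Linear(8, 32, bias=False)

class BackboneStyle(nn.Module):
    # Hypothetical layout: .backbone.embeddings + .lm_head
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.embeddings = nn.Embedding(32, 8)
        self.lm_head = nn.Linear(8, 32, bias=False)

def get_parts(m):
    # Same fallback chain as _get_base_model_parts above
    base = getattr(m, "model", getattr(m, "backbone", None))
    emb = getattr(base, "embed_tokens", getattr(base, "embeddings", None))
    head = getattr(m, "lm_head", None)
    return base, emb, head

for m in (LlamaStyle(), BackboneStyle()):
    base, emb, head = get_parts(m)
    print(type(emb).__name__, type(head).__name__)  # Embedding Linear
```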
@@ -426,34 +440,27 @@ def modify(
         )
         self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config)
 
-        if eagle_offline:
-            # For offline training, the base model has no layers.
-            # Read the device from the lm_head instead.
-            device = self.lm_head.weight.device
-        elif hasattr(self.model.layers[-1].self_attn, "o_proj"):
-            device = self.model.layers[-1].self_attn.o_proj.weight.device
-        elif hasattr(self.model.layers[-1].self_attn, "q_proj"):
-            device = self.model.layers[-1].self_attn.q_proj.weight.device
-        elif hasattr(self.model.layers[-1].self_attn, "qkv_proj"):
-            device = self.model.layers[-1].self_attn.qkv_proj.weight.device
-        self.eagle_module.to(self.dtype).to(device)
-
-        # Make sure self.model.embed_tokens and self.lm_head are frozen
-        for param in self.model.embed_tokens.parameters():
+        self.base_model, self.base_model_embeddings, self.base_model_lm_head = (
+            self._get_base_model_parts()
+        )
+        self.eagle_module.to(self.base_model.dtype).to(self.base_model_lm_head.weight.device)
+
+        # Make sure word embedding and lm head are frozen
+        for param in self.base_model_embeddings.parameters():
             param.requires_grad = False
-        for param in self.lm_head.parameters():
+        for param in self.base_model_lm_head.parameters():
             param.requires_grad = False
 
         # EAGLE-3 auxiliary hidden_states
         if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
             self._aux_hidden_states = []
-            for layer_idx, layer in enumerate(self.model.layers):
+            for layer_idx, layer in enumerate(self.base_model.layers):
                 if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                     layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)
 
         # delete base model layers for offline training
         if eagle_offline:
-            self.model._modules.pop("layers")
+            self.base_model._modules.pop("layers")
 
         # NOTE: this is a temporary hack to bypass hf trainer check:
         # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
@@ -465,7 +472,9 @@ def modify(
     def _get_ttt_attention_mask(self, seq_length, ttt_step):
         # compile and cached flex attention masks in first call
         if ttt_step >= len(self._cached_attn_blk_masks):
-            self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step))
+            self._cached_attn_blk_masks.append(
+                self._compute_ttt_attention_mask(seq_length, ttt_step)
+            )
 
         # return cached flex attention mask
         return self._cached_attn_blk_masks[ttt_step]
@@ -547,15 +556,14 @@ def _get_eagle_module_inputs(
 
         return eagle_input_ids, attention_mask, position_ids
 
-    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
-        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+    def _compute_ttt_attention_mask(self, seq_length, ttt_step) -> BlockMask | torch.Tensor:
+        """Return the TTT attention mask as a BlockMask or a dense tensor, depending on the eagle attn implementation."""
         if ttt_step == 0:
 
             def msk(b, h, q_idx, kv_idx):
                 # symbolic attention mask of shape [seq_len, 2*seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
 
-            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:
 
             def msk(b, h, q_idx, kv_idx):
@@ -565,8 +573,6 @@ def msk(b, h, q_idx, kv_idx):
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )
-
-            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:
 
             def msk(b, h, q_idx, kv_idx):
@@ -577,11 +583,27 @@ def msk(b, h, q_idx, kv_idx):
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )
-
-            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
 
+        dtypemin = torch.finfo(self.config.dtype).min
+        q_len = seq_length
+        kv_len = seq_length * (2 + ttt_step)
+        if self.eagle_module.config._attn_implementation == "flex_attention":
+            block_mask = create_block_mask(msk, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
+            return block_mask
+        else:
+            tensor_mask = msk(
+                None,
+                None,
+                torch.arange(q_len).view(1, 1, q_len, 1),
+                torch.arange(kv_len).view(1, 1, 1, kv_len),
+            ).to(self.device)
+            tensor_mask = torch.full_like(
+                tensor_mask, 0, dtype=self.config.dtype, device=self.device
+            ).masked_fill(~tensor_mask, dtypemin)
+            return tensor_mask
+
     def _base_model_forward(
         self,
         input_ids,
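To sanity-check the new dense-mask branch above, here is roughly what the non-flex path produces for a tiny case (ttt_step = 0, seq_length = 3): the symbolic `msk` is evaluated on broadcast index grids, giving a boolean mask of shape [1, 1, q_len, kv_len] that allows keys strictly before the query in the first block plus the aligned position in the second block, and that boolean mask is then converted to an additive float mask (0 where allowed, dtype-min where blocked). The standalone snippet below is illustrative only; shapes and names mirror the diff but it is not the module code.

```python
import torch

seq_length = 3
q_len, kv_len = seq_length, 2 * seq_length
dtype = torch.float32

def msk(b, h, q_idx, kv_idx):
    # TTT step 0: strictly-earlier positions in the first block of keys,
    # plus the aligned position in the second block.
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

bool_mask = msk(
    None,
    None,
    torch.arange(q_len).view(1, 1, q_len, 1),
    torch.arange(kv_len).view(1, 1, 1, kv_len),
)
print(bool_mask.shape)        # torch.Size([1, 1, 3, 6])
print(bool_mask[0, 0].int())
# tensor([[0, 0, 0, 1, 0, 0],
#         [1, 0, 0, 0, 1, 0],
#         [1, 1, 0, 0, 0, 1]])

# Additive mask for SDPA: 0 where attention is allowed, dtype-min elsewhere.
add_mask = torch.zeros_like(bool_mask, dtype=dtype).masked_fill(
    ~bool_mask, torch.finfo(dtype).min
)
```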
@@ -603,7 +625,7 @@ def _base_model_forward(
             output_hidden_states=True,
             **kwargs,
         )
-        past_key_values = outputs.past_key_values
+        past_key_values = getattr(outputs, "past_key_values", None)
         base_model_hidden_states = outputs.hidden_states[-1]
         base_model_logits = outputs.logits
 
@@ -748,7 +770,7 @@ def forward(
             eagle_cache,
         )
         with torch.no_grad():
-            inputs_embeds = self.model.embed_tokens(eagle_input_ids)
+            inputs_embeds = self.base_model_embeddings(eagle_input_ids)
         position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
 
         # Then, we run eagle forward
@@ -921,7 +943,7 @@ def pseudo_speculative_generate(
         ):
             _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
                 eagle_input_hidden_states,
-                self.model.embed_tokens(eagle_ids),
+                self.base_model_embeddings(eagle_ids),
                 eagle_attention_mask,
                 eagle_position_ids,
                 position_embeddings,