
Commit 20b8d9d

add support for nanov3
Signed-off-by: h-guo18 <[email protected]>
1 parent: 41de55f

2 files changed: 55 additions & 30 deletions


examples/speculative_decoding/main.py

Lines changed: 4 additions & 1 deletion
@@ -140,7 +140,10 @@ def train():
         tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
     else:
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path, torch_dtype="auto", device_map="cpu"
+            model_args.model_name_or_path,
+            torch_dtype="auto",
+            device_map="cpu",
+            trust_remote_code=True,
         )
     if use_offline_training:
         # When doing offline training, we need to set num_hidden_layers
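Note on the new argument: trust_remote_code=True lets AutoModelForCausalLM run modeling code that ships inside the checkpoint repository instead of inside the transformers library, which is presumably what the nanov3 checkpoints require. A minimal standalone sketch, not part of this commit; the checkpoint path is a hypothetical placeholder:

import transformers

# Hypothetical checkpoint path, for illustration only.
ckpt = "path/to/nanov3-checkpoint"

# Without trust_remote_code=True, from_pretrained refuses to execute custom
# modeling code bundled with a checkpoint and raises an error for such models.
model = transformers.AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype="auto",
    device_map="cpu",
    trust_remote_code=True,
)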

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 51 additions & 29 deletions
@@ -185,7 +185,8 @@ def __init__(self, config, decoder_layer_cls, bias=False):
         self.config = config

         # Use flex attention for efficient TTT
-        config._attn_implementation = "flex_attention"
+        # config._attn_implementation = "flex_attention"
+        config.attn_implementation = "sdpa"

         self.layers = nn.ModuleList(
            [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
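Context for the attention switch: flex attention consumes compiled BlockMask objects, while SDPA accepts an additive floating-point attn_mask (0 where attention is allowed, dtype-min where it is blocked), which is the format the TTT masks below are materialized in. A standalone sketch with toy shapes, not part of this commit:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)   # [batch, heads, q_len, head_dim]
k = torch.randn(1, 2, 8, 8)   # kv_len = 8
v = torch.randn(1, 2, 8, 8)

# Additive mask broadcast over batch and heads: 0 keeps a position, dtype-min blocks it.
attn_mask = torch.zeros(1, 1, 4, 8)
attn_mask[..., 4:] = torch.finfo(torch.float32).min  # block the second half of the keys

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 4, 8])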
@@ -373,6 +374,19 @@ def pop_aux_hidden_states(self):

         return aux_h_list

+    def _get_base_model_parts(self):
+        """Helper function to extract model parts from different model types."""
+        base_model = getattr(self, "model", getattr(self, "backbone", None))
+        base_model_embeddings = getattr(
+            base_model, "embed_tokens", getattr(base_model, "embeddings", None)
+        )
+        base_model_lm_head = getattr(self, "lm_head", None)
+        # check if we find all parts
+        for parts in [base_model, base_model_embeddings, base_model_lm_head]:
+            if not isinstance(parts, torch.nn.Module):
+                raise ValueError(f"Part {parts} is not a torch.nn.Module")
+        return base_model, base_model_embeddings, base_model_lm_head
+
     def modify(
         self,
         eagle_offline,
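The getattr fallbacks above are what make the same code path work for causal LMs that expose model.embed_tokens (Llama-style) and for architectures that expose backbone.embeddings instead. A toy sketch, not part of this commit; the two classes are made-up stand-ins rather than real transformers models:

import torch.nn as nn

class LlamaLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Module()                       # base model lives at .model
        self.model.embed_tokens = nn.Embedding(10, 4)  # embeddings at .embed_tokens
        self.lm_head = nn.Linear(4, 10)

class BackboneLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()                     # base model lives at .backbone
        self.backbone.embeddings = nn.Embedding(10, 4)  # embeddings at .embeddings
        self.lm_head = nn.Linear(4, 10)

def get_parts(m):
    base = getattr(m, "model", getattr(m, "backbone", None))
    emb = getattr(base, "embed_tokens", getattr(base, "embeddings", None))
    head = getattr(m, "lm_head", None)
    return base, emb, head

for m in (LlamaLike(), BackboneLike()):
    base, emb, head = get_parts(m)
    assert all(isinstance(p, nn.Module) for p in (base, emb, head))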
@@ -426,34 +440,27 @@ def modify(
         )
         self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config)

-        if eagle_offline:
-            # For offline training, the base model has no layers.
-            # Read the device from the lm_head instead.
-            device = self.lm_head.weight.device
-        elif hasattr(self.model.layers[-1].self_attn, "o_proj"):
-            device = self.model.layers[-1].self_attn.o_proj.weight.device
-        elif hasattr(self.model.layers[-1].self_attn, "q_proj"):
-            device = self.model.layers[-1].self_attn.q_proj.weight.device
-        elif hasattr(self.model.layers[-1].self_attn, "qkv_proj"):
-            device = self.model.layers[-1].self_attn.qkv_proj.weight.device
-        self.eagle_module.to(self.dtype).to(device)
-
-        # Make sure self.model.embed_tokens and self.lm_head are frozen
-        for param in self.model.embed_tokens.parameters():
+        self.base_model, self.base_model_embeddings, self.base_model_lm_head = (
+            self._get_base_model_parts()
+        )
+        self.eagle_module.to(self.base_model.dtype).to(self.base_model_lm_head.weight.device)
+
+        # Make sure word embedding and lm head are frozen
+        for param in self.base_model_embeddings.parameters():
             param.requires_grad = False
-        for param in self.lm_head.parameters():
+        for param in self.base_model_lm_head.parameters():
             param.requires_grad = False

         # EAGLE-3 auxiliary hidden_states
         if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
             self._aux_hidden_states = []
-            for layer_idx, layer in enumerate(self.model.layers):
+            for layer_idx, layer in enumerate(self.base_model.layers):
                 if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                     layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

         # delete base model layers for offline training
         if eagle_offline:
-            self.model._modules.pop("layers")
+            self.base_model._modules.pop("layers")

         # NOTE: this is a temporary hack to bypass hf trainer check:
         # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
@@ -465,7 +472,9 @@ def modify(
     def _get_ttt_attention_mask(self, seq_length, ttt_step):
         # compile and cached flex attention masks in first call
         if ttt_step >= len(self._cached_attn_blk_masks):
-            self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step))
+            self._cached_attn_blk_masks.append(
+                self._compute_ttt_attention_mask(seq_length, ttt_step)
+            )

         # return cached flex attention mask
         return self._cached_attn_blk_masks[ttt_step]
@@ -547,15 +556,14 @@ def _get_eagle_module_inputs(

         return eagle_input_ids, attention_mask, position_ids

-    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
-        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
+    def _compute_ttt_attention_mask(self, seq_length, ttt_step) -> BlockMask | torch.Tensor:
+        """Return the TTT attention mask as a BlockMask or a Tensor, depending on the eagle attention implementation."""
         if ttt_step == 0:

             def msk(b, h, q_idx, kv_idx):
                 # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

-            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:

             def msk(b, h, q_idx, kv_idx):
@@ -565,8 +573,6 @@ def msk(b, h, q_idx, kv_idx):
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )
-
-            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:

             def msk(b, h, q_idx, kv_idx):
@@ -577,11 +583,27 @@ def msk(b, h, q_idx, kv_idx):
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )
-
-            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")

+        dtypemin = torch.finfo(self.config.dtype).min
+        q_len = seq_length
+        kv_len = seq_length * (2 + ttt_step)
+        if self.eagle_module.config._attn_implementation == "flex_attention":
+            block_mask = create_block_mask(msk, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
+            return block_mask
+        else:
+            tensor_mask = msk(
+                None,
+                None,
+                torch.arange(q_len).view(1, 1, q_len, 1),
+                torch.arange(kv_len).view(1, 1, 1, kv_len),
+            ).to(self.device)
+            tensor_mask = torch.full_like(
+                tensor_mask, 0, dtype=self.config.dtype, device=self.device
+            ).masked_fill(~tensor_mask, dtypemin)
+            return tensor_mask
+
     def _base_model_forward(
         self,
         input_ids,
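For reference, the new else branch above turns the symbolic predicate into a dense SDPA mask by broadcasting it over index grids into a [1, 1, q_len, kv_len] boolean tensor and then converting it to an additive 0 / dtype-min mask. A standalone sketch for TTT step 0 with a toy sequence length, not part of this commit:

import torch

seq_length = 4
q_len, kv_len = seq_length, 2 * seq_length

def msk(b, h, q_idx, kv_idx):
    # same step-0 predicate as in the diff above
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

bool_mask = msk(
    None,
    None,
    torch.arange(q_len).view(1, 1, q_len, 1),
    torch.arange(kv_len).view(1, 1, 1, kv_len),
)
additive = torch.zeros_like(bool_mask, dtype=torch.float32).masked_fill(
    ~bool_mask, torch.finfo(torch.float32).min
)
print(bool_mask[0, 0].int())   # allowed positions per query row
print(additive.shape)          # torch.Size([1, 1, 4, 8])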
@@ -603,7 +625,7 @@ def _base_model_forward(
             output_hidden_states=True,
             **kwargs,
         )
-        past_key_values = outputs.past_key_values
+        past_key_values = getattr(outputs, "past_key_values", None)
        base_model_hidden_states = outputs.hidden_states[-1]
        base_model_logits = outputs.logits

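The switch to getattr with a None default covers base-model output types that do not define a past_key_values field at all (rather than returning None for it), so the wrapper no longer assumes a KV-cache-style output. A toy sketch, not part of this commit; the output classes are made-up stand-ins:

from dataclasses import dataclass

@dataclass
class WithCache:
    logits: list
    past_key_values: tuple

@dataclass
class WithoutCache:   # e.g. an output type that exposes no KV cache
    logits: list

for out in (WithCache([0.0], ()), WithoutCache([0.0])):
    pkv = getattr(out, "past_key_values", None)  # None when the field is absent
    print(type(out).__name__, pkv)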
@@ -748,7 +770,7 @@ def forward(
                 eagle_cache,
             )
             with torch.no_grad():
-                inputs_embeds = self.model.embed_tokens(eagle_input_ids)
+                inputs_embeds = self.base_model_embeddings(eagle_input_ids)
             position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)

             # Then, we run eagle forward
@@ -921,7 +943,7 @@ def pseudo_speculative_generate(
         ):
             _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
                 eagle_input_hidden_states,
-                self.model.embed_tokens(eagle_ids),
+                self.base_model_embeddings(eagle_ids),
                 eagle_attention_mask,
                 eagle_position_ids,
                 position_embeddings,
