
Commit 1158d9a

add comments; make #ttt_steps configurable
Signed-off-by: h-guo18 <[email protected]>
1 parent 89c4475 commit 1158d9a

1 file changed: +13 -15 lines changed

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 13 additions & 15 deletions
@@ -36,7 +36,7 @@
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -451,9 +451,11 @@ def modify(
             if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                 layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)
 
+        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
         # compile and cach flex attention masks
         self.cached_attn_blk_masks = [
-            self._get_ttt_block_mask(self.eagle_config.training_seq_len, i) for i in range(3)
+            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
+            for i in range(self.num_ttt_steps)
         ]
 
     def _prepare_decoder_attention_mask(
@@ -533,44 +535,40 @@ def _get_eagle_module_inputs(
 
         return eagle_input_ids, attention_mask, position_ids
 
-    def _get_ttt_block_mask(self, seq_length, ttt_step):
-        """Helper function to get block mask for TTT steps."""
+    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
+        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
         if ttt_step == 0:
 
             def msk(b, h, q_idx, kv_idx):
+                # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
 
-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:
 
             def msk(b, h, q_idx, kv_idx):
+                # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
                 return (
                     (kv_idx <= (q_idx - 2))
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )
 
-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:
 
             def msk(b, h, q_idx, kv_idx):
+                # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
                 return (
                     (kv_idx <= (q_idx - 3))
                     | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )
 
-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
-        return block_mask
 
     def _base_model_forward(
         self,
@@ -763,7 +761,7 @@ def forward(
             train_accs.append(acc)
 
         # ====Perform training-time-testing with 3 extra eagle forward passes====
-        for ttt_step in range(3):
+        for ttt_step in range(self.num_ttt_steps):
             eagle_input_hidden_states = torch.cat(
                 (
                     torch.zeros(
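Context for reviewers unfamiliar with flex attention: the sketch below is illustrative and not part of this commit. It shows how a step-0 TTT block mask like the one compiled above can be built and consumed with torch.nn.attention.flex_attention. It assumes PyTorch >= 2.5; the sequence length, head count, and head dimension are placeholder values standing in for eagle_config.training_seq_len and the model's attention shapes.

# Minimal sketch (illustrative, not from this commit): build the TTT step-0
# block mask and run flex attention with it.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"  # CPU support depends on PyTorch version
seq_len = 128  # placeholder for eagle_config.training_seq_len

def step0_mask(b, h, q_idx, kv_idx):
    # Query q_idx may attend to strictly earlier positions in the first seq_len block
    # (kv_idx <= q_idx - 1) or to its own slot in the second block (kv_idx == q_idx + seq_len),
    # matching the [seq_len, 2 * seq_len] layout used for TTT step 0.
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_len)

block_mask = create_block_mask(
    step0_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=2 * seq_len, device=device
)

# Toy tensors: [batch, heads, tokens, head_dim]; keys/values span both seq_len blocks.
q = torch.randn(1, 8, seq_len, 64, device=device)
kv = torch.randn(1, 8, 2 * seq_len, 64, device=device)
out = flex_attention(q, kv, kv, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])

Caching such BlockMask objects once per TTT step, as the commit does in self.cached_attn_blk_masks, avoids recompiling the mask on every forward pass.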