Commit e4a0b13

add comments; make #ttt_steps configurable

Signed-off-by: h-guo18 <[email protected]>
1 parent d6d3eba
File tree: 1 file changed (+13, -15 lines)

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 13 additions & 15 deletions
@@ -36,7 +36,7 @@
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -452,9 +452,11 @@ def modify(
             if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                 layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

+        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
         # compile and cach flex attention masks
         self.cached_attn_blk_masks = [
-            self._get_ttt_block_mask(self.eagle_config.training_seq_len, i) for i in range(3)
+            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
+            for i in range(self.num_ttt_steps)
         ]

     def _prepare_decoder_attention_mask(
@@ -534,44 +536,40 @@ def _get_eagle_module_inputs(

         return eagle_input_ids, attention_mask, position_ids

-    def _get_ttt_block_mask(self, seq_length, ttt_step):
-        """Helper function to get block mask for TTT steps."""
+    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
+        """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention."""
         if ttt_step == 0:

             def msk(b, h, q_idx, kv_idx):
+                # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:

             def msk(b, h, q_idx, kv_idx):
+                # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
                 return (
                     (kv_idx <= (q_idx - 2))
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )

-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:

             def msk(b, h, q_idx, kv_idx):
+                # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
                 return (
                     (kv_idx <= (q_idx - 3))
                     | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )

-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
-        return block_mask

     def _base_model_forward(
         self,
@@ -764,7 +762,7 @@ def forward(
             train_accs.append(acc)

         # ====Perform training-time-testing with 3 extra eagle forward passes====
-        for ttt_step in range(3):
+        for ttt_step in range(self.num_ttt_steps):
             eagle_input_hidden_states = torch.cat(
                 (
                     torch.zeros(

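For context, the sketch below (not part of the commit) shows how a BlockMask compiled with create_block_mask is consumed by flex_attention. The batch size, head count, head_dim, and seq_len values are illustrative assumptions; the mask predicate is the TTT step-0 predicate copied from the diff above.

# Hypothetical usage sketch, not from this commit: feed a compiled BlockMask
# to flex_attention. Requires a recent PyTorch with flex_attention support.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

seq_len = 128  # stand-in for eagle_config.training_seq_len (assumed value)
n_heads, head_dim = 8, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

def msk(b, h, q_idx, kv_idx):
    # TTT step-0 predicate from the diff: shifted-causal over the first
    # seq_len keys, plus one diagonal into the second seq_len-wide block.
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_len)

block_mask = create_block_mask(
    msk, B=None, H=None, Q_LEN=seq_len, KV_LEN=2 * seq_len, device=device
)

q = torch.randn(1, n_heads, seq_len, head_dim, device=device)
k = torch.randn(1, n_heads, 2 * seq_len, head_dim, device=device)
v = torch.randn(1, n_heads, 2 * seq_len, head_dim, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # (1, n_heads, seq_len, head_dim)

The KV length is 2 * seq_len here simply because the step-0 mask in the diff spans two seq_len-wide blocks of keys/values.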
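The three branches in _compile_ttt_block_mask follow a single pattern: each TTT step shifts the causal part back by one more token, adds one more diagonal, and grows the KV length by one seq_length-wide block. Since the NOTE in the diff says num_ttt_steps might move into the config later, here is a hedged sketch of how one predicate could cover an arbitrary step. The helper name is hypothetical and this is not code from the commit.

# Illustrative generalization, not part of the commit: one mask predicate for
# an arbitrary ttt_step instead of three hardcoded branches.
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

def compile_ttt_block_mask(seq_length: int, ttt_step: int) -> BlockMask:
    def msk(b, h, q_idx, kv_idx):
        # Causal part, shifted one token further back per TTT step.
        keep = kv_idx <= (q_idx - (ttt_step + 1))
        # One diagonal per earlier forward pass, each offset into its own
        # seq_length-wide block of keys/values.
        for j in range(1, ttt_step + 2):
            keep = keep | (
                (kv_idx == q_idx + j * seq_length - (ttt_step + 1 - j))
                & (kv_idx >= j * seq_length)
            )
        return keep

    return create_block_mask(
        msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * (ttt_step + 2)
    )

For ttt_step in {0, 1, 2} this yields the same masks as the three branches above; the extra kv_idx >= seq_length guard on the first diagonal is redundant for step 0 because q_idx is non-negative.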