|
36 | 36 | import torch |
37 | 37 | from torch import nn |
38 | 38 | from torch.nn import CrossEntropyLoss |
39 | | -from torch.nn.attention.flex_attention import create_block_mask |
| 39 | +from torch.nn.attention.flex_attention import BlockMask, create_block_mask |
40 | 40 | from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel |
41 | 41 | from transformers.models.llama.modeling_llama import ( |
42 | 42 | LlamaAttention, |
@@ -452,9 +452,11 @@ def modify( |
452 | 452 | if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: |
453 | 453 | layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) |
454 | 454 |
|
| 455 | + self.num_ttt_steps = 3 # NOTE: (hg) hardcoded for now. Might add to config later. |
455 | 456 | # compile and cache flex attention masks |
456 | 457 | self.cached_attn_blk_masks = [ |
457 | | - self._get_ttt_block_mask(self.eagle_config.training_seq_len, i) for i in range(3) |
| 458 | + self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i) |
| 459 | + for i in range(self.num_ttt_steps) |
458 | 460 | ] |
459 | 461 |
|
460 | 462 | def _prepare_decoder_attention_mask( |
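
The block masks above are compiled once in `modify()` and cached, one `BlockMask` per TTT step, with the key/value span growing by one `seq_len` block per step (step `i` covers `(i + 2) * seq_len` columns). Below is a minimal standalone sketch of that precompile-and-cache pattern; `seq_len`, `num_ttt_steps`, and the causal placeholder predicate are illustrative assumptions, and the real per-step predicates live in `_compile_ttt_block_mask` further down.

```python
# Minimal sketch of the precompile-and-cache pattern (illustrative, not the module code).
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

seq_len = 128
num_ttt_steps = 3  # mirrors self.num_ttt_steps above

def causal_placeholder(b, h, q_idx, kv_idx):
    # stand-in predicate; flex attention evaluates it symbolically over (q_idx, kv_idx)
    return kv_idx <= q_idx

cached_attn_blk_masks: list[BlockMask] = [
    create_block_mask(
        causal_placeholder, B=None, H=None,
        Q_LEN=seq_len, KV_LEN=seq_len * (step + 2), device="cpu",
    )
    for step in range(num_ttt_steps)
]
assert len(cached_attn_blk_masks) == num_ttt_steps
```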
@@ -534,44 +536,40 @@ def _get_eagle_module_inputs( |
534 | 536 |
|
535 | 537 | return eagle_input_ids, attention_mask, position_ids |
536 | 538 |
|
537 | | - def _get_ttt_block_mask(self, seq_length, ttt_step): |
538 | | - """Helper function to get block mask for TTT steps.""" |
| 539 | + def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask: |
| 540 | + """Compile the symbolic attention mask for the given TTT step and return a BlockMask for flex attention.""" |
539 | 541 | if ttt_step == 0: |
540 | 542 |
|
541 | 543 | def msk(b, h, q_idx, kv_idx): |
| 544 | + # symbolic attention mask of shape [seq_len, 2 * seq_len] for TTT step 0 |
542 | 545 | return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length) |
543 | 546 |
|
544 | | - block_mask = create_block_mask( |
545 | | - msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2 |
546 | | - ) |
| 547 | + return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2) |
547 | 548 | elif ttt_step == 1: |
548 | 549 |
|
549 | 550 | def msk(b, h, q_idx, kv_idx): |
| 551 | + # symbolic attention mask of shape [seq_len, 3 * seq_len] for TTT step 1 |
550 | 552 | return ( |
551 | 553 | (kv_idx <= (q_idx - 2)) |
552 | 554 | | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length)) |
553 | 555 | | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2)) |
554 | 556 | ) |
555 | 557 |
|
556 | | - block_mask = create_block_mask( |
557 | | - msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3 |
558 | | - ) |
| 558 | + return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3) |
559 | 559 | elif ttt_step == 2: |
560 | 560 |
|
561 | 561 | def msk(b, h, q_idx, kv_idx): |
| 562 | + # symbolic attention mask of shape [seq_len, 4 * seq_len] for TTT step 2 |
562 | 563 | return ( |
563 | 564 | (kv_idx <= (q_idx - 3)) |
564 | 565 | | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length)) |
565 | 566 | | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2)) |
566 | 567 | | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3)) |
567 | 568 | ) |
568 | 569 |
|
569 | | - block_mask = create_block_mask( |
570 | | - msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4 |
571 | | - ) |
| 570 | + return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4) |
572 | 571 | else: |
573 | 572 | raise ValueError(f"EAGLE TTT step {ttt_step} is not supported") |
574 | | - return block_mask |
575 | 573 |
|
576 | 574 | def _base_model_forward( |
577 | 575 | self, |
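
The three `msk` predicates are symbolic: flex attention evaluates them over `(q_idx, kv_idx)` index pairs rather than over a materialized tensor. As a sanity check (not part of the change), the step-0 predicate can be materialized densely on a toy sequence to see its structure: strictly causal attention over the first `seq_len` keys plus a single diagonal into the second block.

```python
# Illustrative check only: materialize the step-0 predicate as a dense boolean mask
# on a toy seq_len to visualize the [seq_len, 2 * seq_len] pattern it encodes.
import torch

seq_len = 4
q_idx = torch.arange(seq_len).unsqueeze(1)       # [seq_len, 1]
kv_idx = torch.arange(2 * seq_len).unsqueeze(0)  # [1, 2 * seq_len]

dense_step0 = (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_len)
print(dense_step0.int())
# tensor([[0, 0, 0, 0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 0],
#         [1, 1, 1, 0, 0, 0, 0, 1]])
```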
@@ -764,7 +762,7 @@ def forward( |
764 | 762 | train_accs.append(acc) |
765 | 763 |
|
766 | 764 | # ==== Perform training-time testing with self.num_ttt_steps extra eagle forward passes ==== |
767 | | - for ttt_step in range(3): |
| 765 | + for ttt_step in range(self.num_ttt_steps): |
768 | 766 | eagle_input_hidden_states = torch.cat( |
769 | 767 | ( |
770 | 768 | torch.zeros( |
|
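
The hunk above only shows how the inputs for each TTT step are assembled; the diff does not include the attention call itself. As a hedged sketch of how one cached `BlockMask` could be consumed, the snippet below feeds the step-0 mask into `flex_attention`. The tensor names and the head/sequence sizes are illustrative assumptions, not the module's actual wiring.

```python
# Hedged sketch: consuming a cached BlockMask for one TTT step via flex_attention.
# q, k, v, n_heads, and head_dim are illustrative only.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

seq_len, n_heads, head_dim = 128, 8, 64

def step0_mask(b, h, q_idx, kv_idx):
    # same predicate as TTT step 0 above
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_len)

block_mask = create_block_mask(
    step0_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=2 * seq_len, device="cpu"
)

q = torch.randn(1, n_heads, seq_len, head_dim)
k = torch.randn(1, n_heads, 2 * seq_len, head_dim)
v = torch.randn(1, n_heads, 2 * seq_len, head_dim)

out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```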