@@ -36,7 +36,7 @@
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -451,9 +451,11 @@ def modify(
             if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                 layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

+        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
         # compile and cache flex attention masks
         self.cached_attn_blk_masks = [
-            self._get_ttt_block_mask(self.eagle_config.training_seq_len, i) for i in range(3)
+            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
+            for i in range(self.num_ttt_steps)
         ]

     def _prepare_decoder_attention_mask(
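For context, a minimal standalone sketch (not part of this commit) of what the cached list ends up holding: one BlockMask per TTT step, with the KV length growing by one seq_length block each step (2x, 3x, 4x the query length). The predicate here is a causal placeholder; the real per-step predicates are defined in _compile_ttt_block_mask further down.

import torch
from torch.nn.attention.flex_attention import create_block_mask

seq_len = 128
num_ttt_steps = 3
device = "cuda" if torch.cuda.is_available() else "cpu"

def placeholder_mask(step):
    # Stand-in causal-style predicate; see _compile_ttt_block_mask for the real masks.
    def msk(b, h, q_idx, kv_idx):
        return kv_idx <= q_idx + step * seq_len
    return msk

cached_masks = [
    create_block_mask(
        placeholder_mask(i), B=None, H=None,
        Q_LEN=seq_len, KV_LEN=(i + 2) * seq_len, device=device,
    )
    for i in range(num_ttt_steps)
]
print([m.shape for m in cached_masks])  # KV length grows: 2 * seq_len, 3 * seq_len, 4 * seq_len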
@@ -533,44 +535,40 @@ def _get_eagle_module_inputs(

         return eagle_input_ids, attention_mask, position_ids

-    def _get_ttt_block_mask(self, seq_length, ttt_step):
-        """Helper function to get block mask for TTT steps."""
+    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
+        """Compile the symbolic TTT attention mask and return a BlockMask for flex attention."""
         if ttt_step == 0:

             def msk(b, h, q_idx, kv_idx):
+                # symbolic attention mask of shape [seq_len, 2 * seq_len] for TTT step 0
                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
         elif ttt_step == 1:

             def msk(b, h, q_idx, kv_idx):
+                # symbolic attention mask of shape [seq_len, 3 * seq_len] for TTT step 1
                 return (
                     (kv_idx <= (q_idx - 2))
                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
                 )

-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
         elif ttt_step == 2:

             def msk(b, h, q_idx, kv_idx):
+                # symbolic attention mask of shape [seq_len, 4 * seq_len] for TTT step 2
                 return (
                     (kv_idx <= (q_idx - 3))
                     | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
                 )

-            block_mask = create_block_mask(
-                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4
-            )
+            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
         else:
             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
-        return block_mask

     def _base_model_forward(
         self,
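To see what these predicates produce, here is a small sanity-check sketch (illustrative only, with a tiny assumed seq_length, not part of the commit) that materializes the step-0 and step-1 masks densely with plain broadcasting: query i may attend to earlier positions in the first seq_length block (shifted one further back per TTT step) plus one diagonal entry in each appended seq_length block.

import torch

seq_length = 4  # tiny, just to visualize the pattern
q_idx = torch.arange(seq_length).unsqueeze(1)        # [seq_len, 1]

# TTT step 0: KV covers 2 * seq_len positions.
kv_idx = torch.arange(2 * seq_length).unsqueeze(0)   # [1, 2 * seq_len]
step0 = (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
print(step0.int())

# TTT step 1: KV covers 3 * seq_len positions; the causal part and the diagonal
# in the second block shift back by one, and a new diagonal appears in the third block.
kv_idx = torch.arange(3 * seq_length).unsqueeze(0)   # [1, 3 * seq_len]
step1 = (
    (kv_idx <= (q_idx - 2))
    | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
    | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= 2 * seq_length))
)
print(step1.int())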
@@ -763,7 +761,7 @@ def forward(
         train_accs.append(acc)

         # ====Perform training-time-testing with 3 extra eagle forward passes====
-        for ttt_step in range(3):
+        for ttt_step in range(self.num_ttt_steps):
             eagle_input_hidden_states = torch.cat(
                 (
                     torch.zeros(
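To make the consumer side concrete, below is a hedged, self-contained sketch (not from this commit) of how one of the cached BlockMask objects would be fed to flex attention inside a TTT step. The batch size, head count, head dimension, and the reuse of the step-0 predicate are illustrative assumptions, and it needs a PyTorch build with FlexAttention (2.5+); in the model the mask would come from self.cached_attn_blk_masks[ttt_step].

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, S, D = 1, 8, 128, 64  # illustrative batch, heads, training_seq_len, head_dim

def step0_msk(b, h, q_idx, kv_idx):
    # Same predicate as _compile_ttt_block_mask uses for ttt_step == 0.
    return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + S)

block_mask = create_block_mask(step0_msk, B=None, H=None, Q_LEN=S, KV_LEN=2 * S, device=device)

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, 2 * S, D, device=device)  # keys/values span 2 * seq_len positions for step 0
v = torch.randn(B, H, 2 * S, D, device=device)

out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # (B, H, S, D)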