36 | 36 | import torch
37 | 37 | from torch import nn
38 | 38 | from torch.nn import CrossEntropyLoss
39 |    | -from torch.nn.attention.flex_attention import create_block_mask
   | 39 | +from torch.nn.attention.flex_attention import BlockMask, create_block_mask
40 | 40 | from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel
41 | 41 | from transformers.models.llama.modeling_llama import (
42 | 42 |     LlamaAttention,
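Not part of the diff: the new `BlockMask` import is only used below as the return type of the compiled masks. As a minimal sketch of the flow those masks feed into, `create_block_mask` compiles a symbolic `mask_mod` into a `BlockMask` that `flex_attention` consumes in place of a dense attention mask. The causal mask, tensor shapes, and head count here are illustrative assumptions, not code from this PR.

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    # symbolic mask: a query position may attend to kv positions at or before it
    return kv_idx <= q_idx


seq_len, n_heads, head_dim = 128, 8, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# compile the symbolic mask once; B=None and H=None broadcast over batch and heads
block_mask: BlockMask = create_block_mask(
    causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
)

q = k = v = torch.randn(2, n_heads, seq_len, head_dim, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # [2, n_heads, seq_len, head_dim]
```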
@@ -452,9 +452,11 @@ def modify(
452 | 452 |             if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
453 | 453 |                 layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)
454 | 454 | 
    | 455 | +        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
455 | 456 |         # compile and cache flex attention masks
456 | 457 |         self.cached_attn_blk_masks = [
457 |     | -            self._get_ttt_block_mask(self.eagle_config.training_seq_len, i) for i in range(3)
    | 458 | +            self._compile_ttt_block_mask(self.eagle_config.training_seq_len, i)
    | 459 | +            for i in range(self.num_ttt_steps)
458 | 460 |         ]
459 | 461 | 
460 | 462 |     def _prepare_decoder_attention_mask(
@@ -534,44 +536,40 @@ def _get_eagle_module_inputs(
534 | 536 | 
535 | 537 |         return eagle_input_ids, attention_mask, position_ids
536 | 538 | 
537 |     | -    def _get_ttt_block_mask(self, seq_length, ttt_step):
538 |     | -        """Helper function to get block mask for TTT steps."""
    | 539 | +    def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask:
    | 540 | +        """Compile TTT attention masks from symbolic masks and return a BlockMask object for flex attention."""
539 | 541 |         if ttt_step == 0:
540 | 542 | 
541 | 543 |             def msk(b, h, q_idx, kv_idx):
    | 544 | +                # symbolic attention mask of shape [seq_len, 2 * seq_len] for TTT step 0
542 | 545 |                 return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
543 | 546 | 
544 |     | -            block_mask = create_block_mask(
545 |     | -                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2
546 |     | -            )
    | 547 | +            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2)
547 | 548 |         elif ttt_step == 1:
548 | 549 | 
549 | 550 |             def msk(b, h, q_idx, kv_idx):
    | 551 | +                # attention mask of shape [seq_len, 3 * seq_len] for TTT step 1
550 | 552 |                 return (
551 | 553 |                     (kv_idx <= (q_idx - 2))
552 | 554 |                     | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
553 | 555 |                     | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
554 | 556 |                 )
555 | 557 | 
556 |     | -            block_mask = create_block_mask(
557 |     | -                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3
558 |     | -            )
    | 558 | +            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3)
559 | 559 |         elif ttt_step == 2:
560 | 560 | 
561 | 561 |             def msk(b, h, q_idx, kv_idx):
    | 562 | +                # attention mask of shape [seq_len, 4 * seq_len] for TTT step 2
562 | 563 |                 return (
563 | 564 |                     (kv_idx <= (q_idx - 3))
564 | 565 |                     | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
565 | 566 |                     | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
566 | 567 |                     | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
567 | 568 |                 )
568 | 569 | 
569 |     | -            block_mask = create_block_mask(
570 |     | -                msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4
571 |     | -            )
    | 570 | +            return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4)
572 | 571 |         else:
573 | 572 |             raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
574 |     | -        return block_mask
575 | 573 | 
576 | 574 |     def _base_model_forward(
577 | 575 |         self,
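Not part of the diff: the symbolic masks above are easier to read when evaluated densely. The sketch below, assuming a toy `seq_length = 4`, materializes the step-0 and step-1 predicates with plain index arithmetic. Step 0 is a causal band shifted down by one over the base `seq_length` block plus a diagonal over the second block; step 1 shifts the band by two and adds one offset diagonal per extra block.

```python
import torch

seq_length = 4
q_idx = torch.arange(seq_length).unsqueeze(1)  # [seq_len, 1]

# TTT step 0: [seq_len, 2 * seq_len]
kv_idx = torch.arange(seq_length * 2).unsqueeze(0)
step0 = (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
print(step0.int())
# tensor([[0, 0, 0, 0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 0],
#         [1, 1, 1, 0, 0, 0, 0, 1]])

# TTT step 1: [seq_len, 3 * seq_len]
kv_idx = torch.arange(seq_length * 3).unsqueeze(0)
step1 = (
    (kv_idx <= (q_idx - 2))
    | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
    | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
)
print(step1.int())
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
#         [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
#         [1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]])
```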
@@ -764,7 +762,7 @@ def forward(
764 | 762 |         train_accs.append(acc)
765 | 763 | 
766 | 764 |         # ====Perform training-time-testing with 3 extra eagle forward passes====
767 |     | -        for ttt_step in range(3):
    | 765 | +        for ttt_step in range(self.num_ttt_steps):
768 | 766 |             eagle_input_hidden_states = torch.cat(
769 | 767 |                 (
770 | 768 |                     torch.zeros(
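Not part of the diff: the TTT loop now runs `self.num_ttt_steps` times, and the cached mask for step `i` covers `(i + 2) * seq_length` key/value positions. The sketch below is a hypothetical generalization, inferred from the three hand-written masks and not code from this PR, that states the per-step pattern once and checks it against them.

```python
import torch

seq_length, num_ttt_steps = 8, 3
q = torch.arange(seq_length).unsqueeze(1)


def generic_msk(step, q_idx, kv_idx):
    # causal band shifted down by (step + 1) over the base block, plus one
    # diagonal per extra block of previously drafted positions
    out = kv_idx <= (q_idx - (step + 1))
    for j in range(1, step + 2):
        out = out | (
            (kv_idx == q_idx + j * seq_length - (step + 1 - j)) & (kv_idx >= j * seq_length)
        )
    return out


# hand-written masks from the diff, for comparison
hand_written = [
    lambda q_idx, kv_idx: (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length),
    lambda q_idx, kv_idx: (
        (kv_idx <= (q_idx - 2))
        | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
        | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
    ),
    lambda q_idx, kv_idx: (
        (kv_idx <= (q_idx - 3))
        | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
        | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
        | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
    ),
]

for step in range(num_ttt_steps):
    kv = torch.arange(seq_length * (step + 2)).unsqueeze(0)  # KV_LEN = (step + 2) * seq_len
    assert torch.equal(generic_msk(step, q, kv), hand_written[step](q, kv))
```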