@@ -4,7 +4,7 @@
 from math import ceil
 
 import torch
-from torch import nn, arange, stack, cat
+from torch import nn, arange, stack, cat, Tensor
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 
@@ -65,29 +65,32 @@ def compress_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True)
     return block_mask
 
+def create_fine_mask(seq_len, fine_block_size):
 
-def create_fine_mask(selected_block_indices: Tensor, seq_len, fine_block_size):
-    device = selected_block_indices.device
-    batch, heads = selected_block_indices.shape[:2]
+    def inner(selected_block_indices: Tensor):
+        device = selected_block_indices.device
+        batch, heads = selected_block_indices.shape[:2]
 
-    one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
-    one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
+        one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
+        one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
 
-    def fine_mask(b_idx, h_idx, q_idx, kv_idx):
+        def fine_mask(b_idx, h_idx, q_idx, kv_idx):
 
-        compressed_q_idx = q_idx // fine_block_size
-        compressed_kv_idx = kv_idx // fine_block_size
+            compressed_q_idx = q_idx // fine_block_size
+            compressed_kv_idx = kv_idx // fine_block_size
 
-        block_causal_mask = compressed_q_idx > compressed_kv_idx
-        is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
+            block_causal_mask = compressed_q_idx > compressed_kv_idx
+            is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
 
-        causal_mask = q_idx >= kv_idx
-        block_diagonal = compressed_q_idx == compressed_kv_idx
+            causal_mask = q_idx >= kv_idx
+            block_diagonal = compressed_q_idx == compressed_kv_idx
 
-        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
+            return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
 
-    block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
-    return block_mask
+        block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
+        return block_mask
+
+    return inner
 
 # helpers
 
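The hunk above curries create_fine_mask: the outer call now fixes the static arguments (seq_len, fine_block_size), and the returned inner closure turns a tensor of selected block indices into a flex-attention BlockMask. A minimal usage sketch of that reading follows; the shapes, sizes, and random indices are illustrative assumptions rather than values from this repository, and it assumes create_block_mask is imported from torch.nn.attention.flex_attention (its import is not shown in this hunk).

# usage sketch (illustrative values, not from the repository)
import torch

seq_len = 512
fine_block_size = 16

# fix the static arguments once ...
fine_mask_fn = create_fine_mask(seq_len, fine_block_size)

# ... then call the closure with fresh indices each step.
# selected_block_indices: (batch, heads, seq_len, num_selected) integer
# indices into the seq_len // fine_block_size coarse key/value blocks.
selected_block_indices = torch.randint(0, seq_len // fine_block_size, (1, 8, seq_len, 4))

# returns a BlockMask restricting each query to its own fine block
# plus the blocks it selected, under causality
fine_selection_flex_mask = fine_mask_fn(selected_block_indices)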
@@ -241,7 +244,8 @@ def __init__(
     def forward(
         self,
         inp,
-        sliding_window_flex_mask = None
+        sliding_window_flex_mask = None,
+        fine_selection_flex_mask = None
     ):
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
 
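The second hunk threads a matching fine_selection_flex_mask argument through forward. Since the block indices are only selected inside forward, the natural reading is that the caller passes in the closure returned by create_fine_mask (rather than a prebuilt BlockMask) and the module applies it to the indices it computes. The call-site sketch below assumes that reading; attn, tokens, and sliding_window_flex_mask are placeholders, not names taken from this diff.

# call-site sketch under the assumptions stated above
fine_selection_flex_mask = create_fine_mask(seq_len, fine_block_size)

out = attn(
    tokens,
    sliding_window_flex_mask = sliding_window_flex_mask,
    fine_selection_flex_mask = fine_selection_flex_mask
)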
|