
Commit e1d9419

flex for compress attn pathway

1 parent d18b83d

2 files changed: +18 -1 lines changed

  native_sparse_attention_pytorch/native_sparse_attention.py
  native_sparse_attention_pytorch/transformer.py

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 13 additions & 0 deletions

@@ -52,6 +52,19 @@ def sliding_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
 
+def create_compress_mask(seq_len, kv_seq_len, compress_block_size):
+    # cannot be used as using attention logits for importance score
+    # but just to show the immense potential of flex attention
+
+    def compress_mask(_, __, q_idx, kv_idx):
+        compress_kv_idx = (kv_idx * compress_block_size) + (compress_block_size - 1)
+
+        causal_mask = q_idx >= compress_kv_idx
+        return causal_mask
+
+    block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True)
+    return block_mask
+
 # helpers
 
 def exists(v):
native_sparse_attention_pytorch/transformer.py

Lines changed: 5 additions & 1 deletion

@@ -8,7 +8,11 @@
 
 from rotary_embedding_torch import RotaryEmbedding
 
-from native_sparse_attention_pytorch.native_sparse_attention import SparseAttention, create_sliding_mask
+from native_sparse_attention_pytorch.native_sparse_attention import (
+    SparseAttention,
+    create_sliding_mask,
+    create_compress_mask,
+)
 
 # flex attention
 # https://pytorch.org/blog/flexattention/
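
With create_compress_mask now exported, it can be constructed the same way as the sliding mask. A hypothetical usage sketch, not from the commit: per the comment inside create_compress_mask the compressed branch still scores importance from attention logits, so this mask is demonstrative only, and it assumes a PyTorch build that provides flex attention's create_block_mask.

    # hypothetical call site; names and signature taken from the diff above
    from native_sparse_attention_pytorch.native_sparse_attention import create_compress_mask

    seq_len = 512
    compress_block_size = 16

    # one compressed kv position per block of compress_block_size original tokens
    kv_seq_len = seq_len // compress_block_size

    # returns a flex attention BlockMask built with Q_LEN = seq_len, KV_LEN = kv_seq_len
    block_mask = create_compress_mask(seq_len, kv_seq_len, compress_block_size)
    print(block_mask)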
