Commit b1dee31

copy paste in a working impl of flash attention from ring-attention-pytorch repo for modification. get basic scaffolding ready
1 parent 582b844

File tree

3 files changed: +1172 -1 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 14 additions & 1 deletion
@@ -191,6 +191,7 @@ def __init__(
         num_compressed_mem_kv = 1,
         norm = True,
         use_diff_topk = False,
+        use_triton_kernel = False,
         interpolated_importance_score = False,
         query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
         compress_mlp: Module | None = None,
@@ -287,6 +288,8 @@ def __init__(
 
         self.num_selected_blocks = num_selected_blocks
 
+        self.use_triton_kernel = use_triton_kernel
+
         # they combine the three sparse branches through a learned combine with sigmoid activation
 
         if not exists(strategy_combine_mlp):
@@ -438,7 +441,17 @@ def forward(
             gates = gates.cumprod(dim = -1)[..., -1]
             gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)
 
-        if exists(fine_selection_flex_mask):
+        if self.use_triton_kernel:
+            from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
+
+            fine_attn_out = native_sparse_attend(
+                fq, fk, fv,
+                self.selection_block_size,
+                selected_block_indices,
+                fine_num_grouped_queries
+            )
+
+        elif exists(fine_selection_flex_mask):
             # flex attention for the selection for fine attention
 
             fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = fine_num_grouped_queries)
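
With this scaffolding, setting use_triton_kernel = True makes the fine branch skip flex attention and dispatch to native_sparse_attend from the new Triton module. Below is a minimal sketch of toggling the flag from user code; the hyperparameter values and the top-level SparseAttention import follow the repo's usual usage pattern rather than this commit, and the Triton path is assumed to require a CUDA device.

import torch
from native_sparse_attention_pytorch import SparseAttention

# hyperparameter values are illustrative; only use_triton_kernel is introduced by this commit
attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    use_triton_kernel = True   # fine attention now routes through native_sparse_attend
).cuda()                       # assumption: the Triton kernel needs a GPU

tokens = torch.randn(2, 512, 512).cuda()

attended = attn(tokens)        # output keeps the input shape

assert attended.shape == tokens.shape
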

native_sparse_attention_pytorch/transformer.py

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def at_most_one_of(*bools):
+    return sum([*map(int, bools)]) <= 1
+
 # attention
 
 class Attention(Module):
@@ -123,6 +126,7 @@ def __init__(
         use_sparse_attn = False,
         use_flex_sliding_window = False,
         use_flex_fine_selection = False,
+        use_triton_fine_selection = False,
         sparse_attn_kwargs: dict = dict(
             sliding_window_size = 32,
             compress_block_size = 4,
@@ -131,6 +135,8 @@ def __init__(
         )
     ):
         super().__init__()
+        assert at_most_one_of(use_flex_fine_selection, use_triton_fine_selection), 'either using flex or custom triton kernel for fine attn, but not both'
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         if use_flex_sliding_window or use_flex_fine_selection:
@@ -149,6 +155,7 @@ def __init__(
                     dim_head = dim_head,
                     heads = heads,
                     kv_heads = kv_heads,
+                    use_triton_kernel = use_triton_fine_selection,
                     **sparse_attn_kwargs
                 )
             else:
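
At the transformer level the same path is exposed as use_triton_fine_selection, and the new at_most_one_of helper guards against enabling both the flex and Triton fine-selection paths at once. A rough usage sketch follows; depth, the values inside sparse_attn_kwargs, and the forward call returning per-token logits are assumptions taken from the rest of transformer.py rather than from this diff, and the Triton kernel again presumably needs a CUDA device.

import torch
from native_sparse_attention_pytorch.transformer import Transformer

model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,                          # assumed hyperparameter, not part of this diff
    heads = 8,
    dim_head = 64,
    use_sparse_attn = True,
    use_triton_fine_selection = True,   # passed down to SparseAttention as use_triton_kernel
    sparse_attn_kwargs = dict(
        sliding_window_size = 32,
        compress_block_size = 4,
        selection_block_size = 4,       # values beyond the diff are illustrative
        num_selected_blocks = 2
    )
).cuda()                                # assumption: Triton path runs on GPU

ids = torch.randint(0, 256, (1, 512)).cuda()

logits = model(ids)                     # assumed forward signature: token ids in, logits out

# enabling both selection paths at once now trips the new assert:
# Transformer(..., use_flex_fine_selection = True, use_triton_fine_selection = True)  # AssertionError
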
