native_sparse_attention_pytorch/native_sparse_attention.py (14 additions, 1 deletion)
@@ -191,6 +191,7 @@ def __init__(
         num_compressed_mem_kv = 1,
         norm = True,
         use_diff_topk = False,
+        use_triton_kernel = False,
         interpolated_importance_score = False,
         query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
         compress_mlp: Module | None = None,
@@ -287,6 +288,8 @@ def __init__(

         self.num_selected_blocks = num_selected_blocks

+        self.use_triton_kernel = use_triton_kernel
+
         # they combine the three sparse branches through a learned combine with sigmoid activation

         if not exists(strategy_combine_mlp):
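
For context, a minimal sketch of how the new flag would be passed at construction time. The SparseAttention class name and the constructor arguments other than num_selected_blocks and use_triton_kernel are assumptions based on the library's README, not part of this diff, and may differ:

import torch
from native_sparse_attention_pytorch import SparseAttention

attn = SparseAttention(
    dim = 512,                   # assumed model dimension
    dim_head = 64,
    heads = 8,
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    use_triton_kernel = True     # new flag introduced in this diff, stored as self.use_triton_kernel
)

tokens = torch.randn(1, 512, 512)
out = attn(tokens)               # output has the same shape as the input tokens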
@@ -438,7 +441,17 @@ def forward(
             gates = gates.cumprod(dim = -1)[..., -1]
             gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)
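
To illustrate the gating lines in the forward hunk above, a small shape sketch of the straight-through gate handling. The tensor sizes are illustrative placeholders, and grouped_queries stands in for fine_num_grouped_queries:

import torch
from einops import repeat

batch, kv_heads, seq = 2, 4, 16
grouped_queries = 2   # stands in for fine_num_grouped_queries

gates = torch.rand(batch, kv_heads, seq)
gates = gates.cumprod(dim = -1)[..., -1]   # keep only the final cumulative product along the last dim
gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = grouped_queries)

# each kv head's gate is now duplicated for the query heads in its group
assert gates.shape == (batch, kv_heads * grouped_queries)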