native_sparse_attention_pytorch/native_sparse_attention.py (21 additions, 9 deletions)
@@ -189,6 +189,7 @@ def __init__(
         norm = True,
         use_diff_topk = False,
         interpolated_importance_score = False,
+        query_heads_share_selected_kv = True, # if set to True, the importance score is averaged across query heads to select the top-n kv buckets per kv head - can be set to False so that each query head within a group looks at a different set of kv buckets. will of course use more memory and compute
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,
         strategy_combine_mlp: Module | None = None
@@ -272,6 +273,8 @@ def __init__(

         self.interpolated_importance_score = interpolated_importance_score # in the case fine block size < compressed block size, interpolation will weigh the space better when selecting
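For orientation, here is a minimal sketch (not the library's internals) of what the new query_heads_share_selected_kv flag toggles: whether compressed-block importance scores are averaged over the query heads of a GQA group before top-n selection, or ranked separately per query head. The tensor names and shapes below are illustrative assumptions.

import torch

batch, kv_heads, heads_per_group, num_blocks, top_n = 1, 2, 4, 16, 4

# hypothetical importance score for each compressed kv block, kept per query head
# shape: (batch, kv_heads, heads_per_group, num_blocks)
importance = torch.randn(batch, kv_heads, heads_per_group, num_blocks)

query_heads_share_selected_kv = True

if query_heads_share_selected_kv:
    # average over the query heads of each GQA group -> one selection per kv head
    shared = importance.mean(dim = 2)                               # (batch, kv_heads, num_blocks)
    selected = shared.topk(top_n, dim = -1).indices                 # (batch, kv_heads, top_n)
    selected = selected.unsqueeze(2).expand(-1, -1, heads_per_group, -1)
else:
    # each query head keeps its own ranking and may pick different blocks
    selected = importance.topk(top_n, dim = -1).indices             # (batch, kv_heads, heads_per_group, top_n)

print(selected.shape)  # torch.Size([1, 2, 4, 4]) in either branch

Either branch yields indices of the same shape; the shared path simply forces every query head in a group onto the same selected blocks, while the per-head path costs more memory and compute, as the comment in the diff notes.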
train.py (4 additions, 2 deletions)
@@ -30,7 +30,8 @@
 SEQ_LEN = 256

 USE_SPARSE_ATTN = True
-USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient since each layer needs its sparsity dynamically generated, but may be enough just to compare against full attention before going all-in on triton kernels
+USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient since each layer needs its sparsity dynamically generated, but may be enough just to compare against full attention before going all-in on triton kernels
+QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of its corresponding key / value head in GQA
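The comment on USE_FLEX_FOR_FINE_SELECTION points at the cost of routing fine selection through flex attention: the sparsity pattern is data dependent, so the block mask has to be regenerated every forward pass for every layer rather than built once. A hedged sketch of that usage pattern with PyTorch's flex_attention API (2.5+); the selection rule here is an illustrative stand-in, not the repo's actual one.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, N, D, BLOCK = 1, 8, 256, 64, 16

q = torch.randn(B, H, N, D, device = 'cuda')
k = torch.randn(B, H, N, D, device = 'cuda')
v = torch.randn(B, H, N, D, device = 'cuda')

# stand-in for the data dependent fine selection: top blocks chosen per query block
selected = torch.randint(0, N // BLOCK, (N // BLOCK, 4), device = 'cuda')

def fine_selection_mask(b, h, q_idx, kv_idx):
    q_block, kv_block = q_idx // BLOCK, kv_idx // BLOCK
    causal = q_idx >= kv_idx
    picked = (selected[q_block] == kv_block).any(dim = -1)
    diagonal = q_block == kv_block   # always keep the diagonal block so no query row is fully masked
    return causal & (picked | diagonal)

# the mask depends on `selected`, so it must be rebuilt every forward pass, per layer
block_mask = create_block_mask(fine_selection_mask, B = None, H = None, Q_LEN = N, KV_LEN = N, device = 'cuda')

out = flex_attention(q, k, v, block_mask = block_mask)

Rebuilding the block mask each step is exactly the overhead the comment warns about, which is why the comparison against full attention here is a stopgap before going all-in on Triton kernels.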