
Commit 6bf3a1c

triton can now work in original training script
1 parent dba8d2b commit 6bf3a1c

2 files changed: +14 additions, -233 deletions

train.py

Lines changed: 14 additions & 3 deletions
@@ -32,15 +32,16 @@
 KV_HEADS = 4

 USE_SPARSE_ATTN = True
-USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient as each layer needs sparsity dynamically generated, but may be enough just to compare to full attention before going all-in on triton kernels
+USE_TRITON_NSA = True
+USE_FLEX_FOR_FINE_SELECTION = False # will push flex a bit, won't be efficient as each layer needs sparsity dynamically generated, but may be enough just to compare to full attention before going all-in on triton kernels
 QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA

 # sparse attention related

 SLIDING_WINDOW_SIZE = 64
-COMPRESS_BLOCK_SIZE = 64
+COMPRESS_BLOCK_SIZE = 16

-FINE_BLOCK_SIZE = 32
+FINE_BLOCK_SIZE = 16
 NUM_FINE_SELECTED = 1

 INTERPOLATED_IMPORTANCE_SCORE = False
@@ -108,6 +109,15 @@ def base_decoding(

     return out[..., prompt_seq_len:]

+# printing
+
+if USE_TRITON_NSA:
+    print('using custom triton kernel')
+elif USE_FLEX_FOR_FINE_SELECTION:
+    print('using flex attn')
+else:
+    print('sparse attn in regular pytorch')
+
 # model

 model = Transformer(
@@ -119,6 +129,7 @@ def base_decoding(
     kv_heads = KV_HEADS,
     use_sparse_attn = USE_SPARSE_ATTN,
     use_flex_sliding_window = True,
+    use_triton_fine_selection = USE_TRITON_NSA,
     use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION,
     sparse_attn_kwargs = dict(
         sliding_window_size = SLIDING_WINDOW_SIZE,
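
The three fine-selection backends toggled above are mutually exclusive: the custom triton kernel takes precedence, then flex attention, then the plain PyTorch path, exactly as the new if / elif / else print block reflects. Below is a minimal sketch of that precedence as a standalone helper; `pick_fine_selection_backend` is a hypothetical name for illustration only and is not part of this commit or the library.

```python
# Hypothetical helper (illustration only, not part of this commit): collapses
# the backend flags from train.py into a single label, mirroring the
# precedence of the new if / elif / else block.

def pick_fine_selection_backend(use_triton_nsa: bool, use_flex_fine_selection: bool) -> str:
    if use_triton_nsa:               # custom triton kernel wins if enabled
        return 'triton'
    if use_flex_fine_selection:      # otherwise fall back to flex attention
        return 'flex'
    return 'pytorch'                 # default: sparse attn in regular pytorch

# With the values set in this commit (USE_TRITON_NSA = True,
# USE_FLEX_FOR_FINE_SELECTION = False) this resolves to 'triton'.
print(pick_fine_selection_backend(True, False))  # -> triton
```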

train_triton_nsa.py

Lines changed: 0 additions & 230 deletions
This file was deleted.
