
Commit 2cf4646

train script

1 parent 0ad8c5e commit 2cf4646
1 file changed: +28 -15 lines changed

train.py

Lines changed: 28 additions & 15 deletions
@@ -26,18 +26,31 @@
 VALIDATE_EVERY = 100
 PRIME_LENGTH = 64
 GENERATE_EVERY = 500
-GENERATE_LENGTH = 256
-SEQ_LEN = 256
+GENERATE_LENGTH = 512
+SEQ_LEN = 512
+HEADS = 8
+KV_HEADS = 4
 
-USE_SPARSE_ATTN = True
+USE_SPARSE_ATTN = False
 USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient as each layer needs sparsity dynamically generated, but may be enough just to compare to full attention before going all-in on triton kernels
 QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA
 
+# sparse attention related
+
+SLIDING_WINDOW_SIZE = 64
+COMPRESS_BLOCK_SIZE = 64
+
+FINE_BLOCK_SIZE = 32
+NUM_FINE_SELECTED = 0
+
+INTERPOLATED_IMPORTANCE_SCORE = True
+USE_DIFF_TOPK = True
+
 # experiment related
 
 PROJECT_NAME = 'native-sparse-attention'
-RUN_NAME = 'baseline' if not USE_SPARSE_ATTN else 'sparse-attn'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+RUN_NAME = 'baseline' if not USE_SPARSE_ATTN else f'sparse-attn: compress size {COMPRESS_BLOCK_SIZE} | fine size {FINE_BLOCK_SIZE} | {NUM_FINE_SELECTED} selected'
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 
 # helpers
 
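For context on the new HEADS / KV_HEADS constants and the QUERY_HEADS_SHARE_SELECTION flag above: with 8 query heads over 4 key/value heads (GQA), two query heads share each kv head, and the flag decides whether that pair must reuse one set of selected fine blocks or may each pick its own. A minimal shape-level sketch in plain PyTorch, not the library's internals; num_blocks, the random scores, and the mean-pooled sharing rule are all illustrative (k is set to 2 here, since the commit's NUM_FINE_SELECTED of 0 would select nothing):

import torch

HEADS, KV_HEADS = 8, 4
GROUPS = HEADS // KV_HEADS   # 2 query heads per kv head
num_blocks, k = 16, 2        # illustrative; the commit sets NUM_FINE_SELECTED = 0

# hypothetical importance score for every fine block, per query head
scores = torch.randn(1, HEADS, num_blocks)  # (batch, query heads, blocks)

# QUERY_HEADS_SHARE_SELECTION = True: pool scores across each GQA group,
# so both query heads on a kv head end up attending to the same blocks
shared_scores = scores.reshape(1, KV_HEADS, GROUPS, num_blocks).mean(dim = 2)
shared_sel = shared_scores.topk(k, dim = -1).indices    # (1, 4, 2)

# QUERY_HEADS_SHARE_SELECTION = False: each query head selects independently
per_head_sel = scores.topk(k, dim = -1).indices         # (1, 8, 2)

print(shared_sel.shape, per_head_sel.shape)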

@@ -101,24 +114,24 @@ def base_decoding(
     num_tokens = 256,
     dim = 512,
     depth = 6,
-    heads = 8,
+    heads = HEADS,
     dim_head = 64,
-    kv_heads = 4,
+    kv_heads = KV_HEADS,
     use_sparse_attn = USE_SPARSE_ATTN,
     use_flex_sliding_window = True,
     use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION,
     sparse_attn_kwargs = dict(
-        sliding_window_size = 32,
-        compress_block_size = 32,
+        sliding_window_size = SLIDING_WINDOW_SIZE,
+        compress_block_size = COMPRESS_BLOCK_SIZE,
         compress_mlp = GroupedMLP(
             dim_head = 64,
-            compress_block_size = 32,
-            heads= 4,
+            compress_block_size = COMPRESS_BLOCK_SIZE,
+            heads = KV_HEADS,
         ),
-        selection_block_size = 32,
-        num_selected_blocks = 2,
-        use_diff_topk = True,
-        interpolated_importance_score = True,
+        selection_block_size = FINE_BLOCK_SIZE,
+        num_selected_blocks = NUM_FINE_SELECTED,
+        use_diff_topk = USE_DIFF_TOPK,
+        interpolated_importance_score = INTERPOLATED_IMPORTANCE_SCORE,
         query_heads_share_selected_kv = QUERY_HEADS_SHARE_SELECTION
     )
 ).cuda()
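USE_DIFF_TOPK = True gates the selected fine blocks with a differentiable top-k, so the block importance scores receive gradient even though the selection itself is discrete. The repo's exact formulation isn't reproduced here; a common straight-through variant of the idea, with all names and shapes illustrative, looks like:

import torch

def straight_through_topk_gates(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Hard 0/1 gates in the forward pass, soft gradients in the backward pass."""
    soft = scores.softmax(dim = -1)                         # differentiable relaxation
    hard = torch.zeros_like(soft)
    hard.scatter_(-1, soft.topk(k, dim = -1).indices, 1.)   # exact top-k mask
    # forward: hard mask; backward: gradient of the soft distribution
    return hard + soft - soft.detach()

scores = torch.randn(2, 8, 16, requires_grad = True)  # (batch, heads, fine blocks)
gates = straight_through_topk_gates(scores, k = 2)
gates.sum().backward()                                # gradients reach the scores
print(gates[0, 0], scores.grad is not None)

In the forward pass the gates are exactly 0 or 1, so attention only touches k blocks; in the backward pass the gradient of the softmax relaxation flows to every score.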

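Finally, the compress branch: COMPRESS_BLOCK_SIZE = 64 means each 64-token block of keys/values is folded into a single summary token for the coarse attention pass, which is the role of the GroupedMLP passed as compress_mlp. A toy stand-in showing just the reshape such a module operates over (the real GroupedMLP presumably keeps separate weights per kv head, which this single nn.Sequential does not):

import torch
from torch import nn

COMPRESS_BLOCK_SIZE = 64
dim_head = 64

# hypothetical per-block compressor: one MLP applied to each block of 64 tokens
compress_mlp = nn.Sequential(
    nn.Linear(COMPRESS_BLOCK_SIZE * dim_head, dim_head * 4),
    nn.ReLU(),
    nn.Linear(dim_head * 4, dim_head),
)

keys = torch.randn(1, 4, 512, dim_head)   # (batch, kv heads, seq, dim head)
b, h, n, d = keys.shape
blocks = keys.reshape(b, h, n // COMPRESS_BLOCK_SIZE, COMPRESS_BLOCK_SIZE * d)
compressed = compress_mlp(blocks)         # (1, 4, 8, 64): one token per block
print(compressed.shape)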