
Commit 671534a

setup for students

1 parent: e40b35f
1 file changed: 12 additions, 0 deletions

train.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,12 @@
 
 from native_sparse_attention_pytorch.transformer import Transformer
 
+from native_sparse_attention_pytorch.compress_networks import (
+    ConvLinearCompress,
+    AttentionPool,
+    GroupedMLP
+)
+
 # constants
 
 NUM_BATCHES = int(1e5)
@@ -95,13 +101,19 @@ def base_decoding(
     dim = 512,
     depth = 6,
     heads = 8,
+    dim_head = 64,
     kv_heads = 4,
     use_sparse_attn = USE_SPARSE_ATTN,
     use_flex_sliding_window = True,
     use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION,
     sparse_attn_kwargs = dict(
         sliding_window_size = 32,
         compress_block_size = 32,
+        compress_mlp = GroupedMLP(
+            dim_head = 64,
+            compress_block_size = 32,
+            heads = 4,
+        ),
         selection_block_size = 32,
         num_selected_blocks = 2,
         use_diff_topk = False,

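For readers following along, here is a minimal sketch of the model construction this commit sets up, pieced together only from the keyword arguments visible in the diff above. The num_tokens value and the literal booleans substituted for the script's USE_SPARSE_ATTN and USE_FLEX_FOR_FINE_SELECTION constants are assumptions, not part of this commit; everything else mirrors the diff.

# Minimal sketch, not part of the commit: arguments marked "assumption"
# are hypothetical stand-ins; the rest are copied from the diff above.

from native_sparse_attention_pytorch.transformer import Transformer
from native_sparse_attention_pytorch.compress_networks import GroupedMLP

model = Transformer(
    num_tokens = 256,                 # assumption: vocab size is not shown in the diff
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    kv_heads = 4,
    use_sparse_attn = True,           # stands in for the script's USE_SPARSE_ATTN flag
    use_flex_sliding_window = True,
    use_flex_fine_selection = False,  # stands in for USE_FLEX_FOR_FINE_SELECTION
    sparse_attn_kwargs = dict(
        sliding_window_size = 32,
        compress_block_size = 32,
        # the network added by this commit: a grouped MLP that summarizes each
        # block of 32 key/value tokens for the compression branch
        compress_mlp = GroupedMLP(
            dim_head = 64,
            compress_block_size = 32,
            heads = 4,                # matches kv_heads above
        ),
        selection_block_size = 32,
        num_selected_blocks = 2,
        use_diff_topk = False,
    ),
)

The commit also imports ConvLinearCompress and AttentionPool, which are not used in the shown hunks; they could presumably be swapped in for GroupedMLP as compress_mlp, but their constructor arguments are not visible in this diff.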