File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change 1111
1212from native_sparse_attention_pytorch .transformer import Transformer
1313
14+ from native_sparse_attention_pytorch .compress_networks import (
15+ ConvLinearCompress ,
16+ AttentionPool ,
17+ GroupedMLP
18+ )
19+
1420# constants
1521
1622NUM_BATCHES = int (1e5 )
@@ -95,13 +101,19 @@ def base_decoding(
95101 dim = 512 ,
96102 depth = 6 ,
97103 heads = 8 ,
104+ dim_head = 64 ,
98105 kv_heads = 4 ,
99106 use_sparse_attn = USE_SPARSE_ATTN ,
100107 use_flex_sliding_window = True ,
101108 use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION ,
102109 sparse_attn_kwargs = dict (
103110 sliding_window_size = 32 ,
104111 compress_block_size = 32 ,
112+ compress_mlp = GroupedMLP (
113+ dim_head = 64 ,
114+ compress_block_size = 32 ,
115+ heads = 4 ,
116+ ),
105117 selection_block_size = 32 ,
106118 num_selected_blocks = 2 ,
107119 use_diff_topk = False ,
You can’t perform that action at this time.
0 commit comments