|
26 | 26 | VALIDATE_EVERY = 100 |
27 | 27 | PRIME_LENGTH = 64 |
28 | 28 | GENERATE_EVERY = 500 |
29 | | -GENERATE_LENGTH = 256 |
30 | | -SEQ_LEN = 256 |
| 29 | +GENERATE_LENGTH = 512 |
| 30 | +SEQ_LEN = 512 |
| 31 | +HEADS = 8 |
| 32 | +KV_HEADS = 4 |
31 | 33 |
|
32 | | -USE_SPARSE_ATTN = True |
| 34 | +USE_SPARSE_ATTN = False |
33 | 35 | USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit; won't be efficient since each layer needs its sparsity dynamically generated, but may be enough to compare against full attention before going all-in on triton kernels |
34 | 36 | QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of its corresponding key / value head in GQA |
35 | 37 |
|
| 38 | +# sparse attention related |
| 39 | + |
| 40 | +SLIDING_WINDOW_SIZE = 64 |
| 41 | +COMPRESS_BLOCK_SIZE = 64 |
| 42 | + |
| 43 | +FINE_BLOCK_SIZE = 32 |
| 44 | +NUM_FINE_SELECTED = 0 |
| 45 | + |
| 46 | +INTERPOLATED_IMPORTANCE_SCORE = True |
| 47 | +USE_DIFF_TOPK = True |
| 48 | + |
36 | 49 | # experiment related |
37 | 50 |
|
38 | 51 | PROJECT_NAME = 'native-sparse-attention' |
39 | | -RUN_NAME = 'baseline' if not USE_SPARSE_ATTN else 'sparse-attn' |
40 | | -WANDB_ONLINE = False # turn this on to pipe experiment to cloud |
| 52 | +RUN_NAME = 'baseline' if not USE_SPARSE_ATTN else f'sparse-attn: compress size {COMPRESS_BLOCK_SIZE} | fine size {FINE_BLOCK_SIZE} | {NUM_FINE_SELECTED} selected' |
| 53 | +WANDB_ONLINE = True # turn this on to pipe the experiment to the cloud |
41 | 54 |
|
42 | 55 | # helpers |
43 | 56 |
|
@@ -101,24 +114,24 @@ def base_decoding( |
101 | 114 | num_tokens = 256, |
102 | 115 | dim = 512, |
103 | 116 | depth = 6, |
104 | | - heads = 8, |
| 117 | + heads = HEADS, |
105 | 118 | dim_head = 64, |
106 | | - kv_heads = 4, |
| 119 | + kv_heads = KV_HEADS, |
107 | 120 | use_sparse_attn = USE_SPARSE_ATTN, |
108 | 121 | use_flex_sliding_window = True, |
109 | 122 | use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION, |
110 | 123 | sparse_attn_kwargs = dict( |
111 | | - sliding_window_size = 32, |
112 | | - compress_block_size = 32, |
| 124 | + sliding_window_size = SLIDING_WINDOW_SIZE, |
| 125 | + compress_block_size = COMPRESS_BLOCK_SIZE, |
113 | 126 | compress_mlp = GroupedMLP( |
114 | 127 | dim_head = 64, |
115 | | - compress_block_size = 32, |
116 | | - heads= 4, |
| 128 | + compress_block_size = COMPRESS_BLOCK_SIZE, |
| 129 | + heads = KV_HEADS, |
117 | 130 | ), |
118 | | - selection_block_size = 32, |
119 | | - num_selected_blocks = 2, |
120 | | - use_diff_topk = True, |
121 | | - interpolated_importance_score = True, |
| 131 | + selection_block_size = FINE_BLOCK_SIZE, |
| 132 | + num_selected_blocks = NUM_FINE_SELECTED, |
| 133 | + use_diff_topk = USE_DIFF_TOPK, |
| 134 | + interpolated_importance_score = INTERPOLATED_IMPORTANCE_SCORE, |
122 | 135 | query_heads_share_selected_kv = QUERY_HEADS_SHARE_SELECTION |
123 | 136 | ) |
124 | 137 | ).cuda() |
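A quick note on the new `HEADS` / `KV_HEADS` constants: with 8 query heads over 4 key/value heads, each key/value head serves a group of 2 query heads (GQA). The sketch below is only an illustration of that grouping and of what `QUERY_HEADS_SHARE_SELECTION = False` implies (each query head in a group picking its own fine blocks rather than reusing the group's selection); it is plain arithmetic over the constants in this diff, not the library's implementation.

```python
# Minimal sketch of the GQA grouping implied by HEADS = 8 and KV_HEADS = 4.
# Illustration only; this is not code from native-sparse-attention-pytorch.

HEADS = 8
KV_HEADS = 4

assert HEADS % KV_HEADS == 0
group_size = HEADS // KV_HEADS  # query heads sharing one key/value head -> 2

for q_head in range(HEADS):
    kv_head = q_head // group_size
    print(f'query head {q_head} -> kv head {kv_head}')

# With QUERY_HEADS_SHARE_SELECTION = False, the 2 query heads mapped to the same
# kv head may each select different fine blocks; with True they share one selection.
```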
|
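For a rough sense of the per-query attention budget these constants imply at the new `SEQ_LEN` of 512, the sketch below tallies the three branches. It is back-of-the-envelope arithmetic only, assuming the constants above; the real masks are causal and block-aligned, and with `NUM_FINE_SELECTED = 0` the fine branch presumably selects no blocks for this run, so the comparison is compressed plus sliding-window attention against the full-attention baseline.

```python
# Back-of-the-envelope per-query attention budget under the constants in this diff.
# Illustrative arithmetic only, not the actual masking logic of the library.

SEQ_LEN = 512
SLIDING_WINDOW_SIZE = 64
COMPRESS_BLOCK_SIZE = 64
FINE_BLOCK_SIZE = 32
NUM_FINE_SELECTED = 0

sliding_tokens = SLIDING_WINDOW_SIZE                   # raw tokens in the local window
fine_tokens    = NUM_FINE_SELECTED * FINE_BLOCK_SIZE   # raw tokens via fine selection (0 here)
compressed_kvs = SEQ_LEN // COMPRESS_BLOCK_SIZE        # pooled block summaries

print(f'sliding window : {sliding_tokens} tokens')
print(f'fine selection : {fine_tokens} tokens')
print(f'compressed     : {compressed_kvs} block summaries ({COMPRESS_BLOCK_SIZE} tokens each)')
print(f'full attention : {SEQ_LEN} tokens per query (baseline)')
```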