Skip to content

Commit 7930389

Browse files
committed
Revert some files
1 parent 206dac9, commit 7930389

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

src/llama-hparams.h

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -102,20 +102,12 @@ struct llama_hparams {
102102

103103
// Sliding Window Attention (SWA)
104104
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
105-
106-
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
107-
uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
108-
// by default n == 1, all layers are dense
109-
// note that if n_swa_pattern == 0, all layers are SWA
110-
// example: n_swa_pattern = 3
111-
// il == 0: swa
112-
// il == 1: swa
113-
// il == 2: dense
114-
// il == 3: swa
115-
// il == 4: swa
116-
// il == 5: dense
117-
// il == 6: swa
118-
// etc ...
105+
// the size of the sliding window (0 - no SWA)
106+
uint32_t n_swa = 0;
107+
// if swa_layers[il] == true, then layer il is SWA
108+
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
109+
// by default, all layers are dense
110+
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
119111

120112
// for State Space Models
121113
uint32_t ssm_d_conv = 0;
@@ -153,6 +145,23 @@ struct llama_hparams {
153145
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
154146
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
155147

148+
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
149+
// note that if n_pattern == 0, all layers are SWA
150+
// if n_pattern == 1, all layers are dense
151+
// example: n_pattern = 3
152+
// il == 0: swa
153+
// il == 1: swa
154+
// il == 2: dense
155+
// il == 3: swa
156+
// il == 4: swa
157+
// il == 5: dense
158+
// il == 6: swa
159+
// etc ...
160+
void set_swa_pattern(uint32_t n_pattern);
161+
162+
// return true if one of the layers is SWA
163+
bool is_swa_any() const;
164+
156165
uint32_t n_head(uint32_t il = 0) const;
157166

158167
uint32_t n_head_kv(uint32_t il = 0) const;

0 commit comments

Comments
 (0)