File tree Expand file tree Collapse file tree 3 files changed +19
-1
lines changed Expand file tree Collapse file tree 3 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -8,6 +8,12 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
88 }
99}
1010
11+ void llama_hparams::set_swa_pattern (uint32_t n_pattern, uint32_t remainder) {
12+ for (uint32_t il = 0 ; il < n_layer; ++il) {
13+ swa_layers[il] = n_pattern == 0 || (il % n_pattern != remainder);
14+ }
15+ }
16+
1117bool llama_hparams::is_swa_any () const {
1218 for (uint32_t il = 0 ; il < n_layer; ++il) {
1319 if (swa_layers[il]) {
Original file line number Diff line number Diff line change @@ -162,6 +162,18 @@ struct llama_hparams {
162162 // etc ...
163163 void set_swa_pattern (uint32_t n_pattern);
164164
165+ // Overload that lets the caller choose which position in the repeating pattern is dense
166+ // Layer il is dense when il % n_pattern == remainder and SWA otherwise (n_pattern == 0: all layers are SWA)
167+ // example: n_pattern = 3, remainder = 2
168+ // il == 0: swa (0 % 3 = 0, which is not equal to 2)
169+ // il == 1: swa (1 % 3 = 1, which is not equal to 2)
170+ // il == 2: dense (2 % 3 = 2, which equals 2)
171+ // il == 3: swa (3 % 3 = 0, which is not equal to 2)
172+ // il == 4: swa (4 % 3 = 1, which is not equal to 2)
173+ // il == 5: dense (5 % 3 = 2, which equals 2)
174+ // etc ...
175+ void set_swa_pattern (uint32_t n_pattern, uint32_t remainder);
176+
165177 // return true if one of the layers is SWA
166178 bool is_swa_any () const ;
167179
Original file line number Diff line number Diff line change @@ -714,7 +714,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
714714 hparams.n_swa = 128;
715715
716716 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
717- hparams.set_swa_pattern(3);
717+ hparams.set_swa_pattern(3, 0 );
718718
719719 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
720720 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
You can’t perform that action at this time.
0 commit comments