Skip to content

Commit 7143840

Browse files
committed
overload set_swa_pattern for modern bert
1 parent eab776e commit 7143840

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

src/llama-hparams.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
88
}
99
}
1010

11+
void llama_hparams::set_swa_pattern(uint32_t n_pattern, uint32_t remainder) {
12+
for (uint32_t il = 0; il < n_layer; ++il) {
13+
swa_layers[il] = n_pattern == 0 || (il % n_pattern != remainder);
14+
}
15+
}
16+
1117
bool llama_hparams::is_swa_any() const {
1218
for (uint32_t il = 0; il < n_layer; ++il) {
1319
if (swa_layers[il]) {

src/llama-hparams.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,18 @@ struct llama_hparams {
162162
// etc ...
163163
void set_swa_pattern(uint32_t n_pattern);
164164

165+
// Overload that allows specifying which position in the pattern is dense
166+
// The remainder parameter specifies which position in the pattern is dense
167+
// example: n_pattern = 3, remainder = 2
168+
// il == 0: swa (0 % 3 = 0, which is not equal to 2)
169+
// il == 1: swa (1 % 3 = 1, which is not equal to 2)
170+
// il == 2: dense (2 % 3 = 2, which equals 2)
171+
// il == 3: swa (3 % 3 = 0, which is not equal to 2)
172+
// il == 4: swa (4 % 3 = 1, which is not equal to 2)
173+
// il == 5: dense (5 % 3 = 2, which equals 2)
174+
// etc ...
175+
void set_swa_pattern(uint32_t n_pattern, uint32_t remainder);
176+
165177
// return true if one of the layers is SWA
166178
bool is_swa_any() const;
167179

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
714714
hparams.n_swa = 128;
715715

716716
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
717-
hparams.set_swa_pattern(3);
717+
hparams.set_swa_pattern(3, 0);
718718

719719
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
720720
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);

0 commit comments

Comments
 (0)