
Commit 4d20599

use n_swa_pattern from llama.cpp #12373
1 parent f6b3831 commit 4d20599

1 file changed: +9 −4 lines changed


llama.cpp/llama.cpp

Lines changed: 9 additions & 4 deletions
@@ -1957,6 +1957,7 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -2025,6 +2026,7 @@ struct llama_hparams {
        if (this->n_layer != other.n_layer) return true;
        if (this->n_rot   != other.n_rot)   return true;
        if (this->n_swa   != other.n_swa)   return true;
+       if (this->n_swa_pattern != other.n_swa_pattern) return true;
        if (this->n_embd_head_k != other.n_embd_head_k) return true;
        if (this->n_embd_head_v != other.n_embd_head_v) return true;
        if (this->n_expert != other.n_expert) return true;
@@ -4809,6 +4811,8 @@ static void llm_load_hparams(
        case LLM_ARCH_GEMMA2:
            {
                hparams.n_swa = 4096; // default value of gemma 2
+               hparams.n_swa_pattern = 2;
+
                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
@@ -4824,12 +4828,12 @@ static void llm_load_hparams(
            } break;
        case LLM_ARCH_GEMMA3:
            {
-               hparams.n_swa = 1024;
+               hparams.n_swa_pattern = 6;

                hparams.rope_freq_base_train_swa = 10000.0f;
                hparams.rope_freq_scale_train_swa = 1.0f;

-               ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+               ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
@@ -11811,7 +11815,8 @@ struct llm_build_context {

        for (int il = 0; il < n_layer; ++il) {
            // (il % 2) layers use SWA
-           struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+           const bool is_swa = il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1);
+           struct ggml_tensor * KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;

            // norm
            cur = llm_build_norm(ctx0, inpL, hparams,
@@ -11944,7 +11949,7 @@ struct llm_build_context {
        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);

        for (int il = 0; il < n_layer; ++il) {
-           const bool is_swa = il % 6 < 5;
+           const bool is_swa = il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1);

            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
            const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
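
For reference, a minimal standalone C++ sketch (not part of the commit) of how the il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) check classifies layers. The pattern values mirror the defaults set above: 2 for Gemma 2 (even layers use SWA), 6 for Gemma 3 (every sixth layer, il % 6 == 5, falls back to full attention), and the default of 1 keeps all layers non-SWA.

    // sketch only: reproduces the layer classification, not the actual graph build
    #include <cstdint>
    #include <cstdio>

    // true  -> layer uses sliding-window attention (SWA)
    // false -> layer uses full (non-sliding-window) attention
    static bool layer_is_swa(uint32_t il, uint32_t n_swa_pattern) {
        // n_swa_pattern == 1: il % 1 == 0 is never < 0, so every layer is non-SWA (the default)
        return il % n_swa_pattern < (n_swa_pattern - 1);
    }

    int main() {
        for (uint32_t il = 0; il < 12; ++il) {
            printf("layer %2u: pattern=2 -> %s, pattern=6 -> %s\n",
                   il,
                   layer_is_swa(il, 2) ? "SWA " : "full",
                   layer_is_swa(il, 6) ? "SWA " : "full");
        }
        return 0;
    }

With pattern=2 this prints SWA for layers 0, 2, 4, ... and full for odd layers, matching the old (il % 2 == 0) check that the commit replaces; with pattern=6 only layers 5 and 11 print full, matching the old hard-coded il % 6 < 5.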
