@@ -1957,6 +1957,7 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -2025,6 +2026,7 @@ struct llama_hparams {
         if (this->n_layer != other.n_layer) return true;
         if (this->n_rot != other.n_rot) return true;
         if (this->n_swa != other.n_swa) return true;
+        if (this->n_swa_pattern != other.n_swa_pattern) return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_expert != other.n_expert) return true;
@@ -4809,6 +4811,8 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 hparams.n_swa = 4096; // default value of gemma 2
+                hparams.n_swa_pattern = 2;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
@@ -4824,12 +4828,12 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                hparams.n_swa = 1024;
+                hparams.n_swa_pattern = 6;
 
                 hparams.rope_freq_base_train_swa = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
 
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
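
The pattern value is what the per-layer checks further down key off: within every group of n_swa_pattern consecutive layers, the last layer uses full (non-sliding-window) attention and the rest use SWA, so n_swa_pattern = 1 keeps every layer on full attention, 2 reproduces Gemma 2's alternation, and 6 gives Gemma 3's five-SWA-then-one-full layout. Below is a minimal standalone sketch of that mapping (not part of the diff; the 12-layer count is illustrative only):

// Standalone illustration of the n_swa_pattern convention used by the builders below.
#include <cstdio>

int main() {
    const int n_layer       = 12; // illustrative layer count, not a real model value
    const int n_swa_pattern = 6;  // Gemma 3 value from the hunk above

    for (int il = 0; il < n_layer; ++il) {
        // the last layer of each n_swa_pattern-sized group gets full attention
        const bool is_swa = il % n_swa_pattern < (n_swa_pattern - 1);
        printf("layer %2d: %s\n", il, is_swa ? "SWA" : "full attention");
    }
    return 0;
}

With n_swa_pattern = 6 this prints layers 0-4 and 6-10 as SWA and layers 5 and 11 as full attention, matching the hard-coded il % 6 < 5 check removed in the last hunk; with n_swa_pattern = 2 it reproduces the old (il % 2 == 0) alternation from the Gemma 2 builder.
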
@@ -11811,7 +11815,8 @@ struct llm_build_context {
 
         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
-            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+            const bool is_swa = il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1);
+            struct ggml_tensor * KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -11944,7 +11949,7 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_swa = il % 6 < 5;
+            const bool is_swa = il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1);
 
             const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
             const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
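
Besides mask selection, is_swa also decides which RoPE parameters a layer gets: SWA layers keep the training-time values stored in hparams (rope_freq_base_train_swa / rope_freq_scale_train_swa, set to 10000.0f / 1.0f for Gemma 3 above), while full-attention layers take the runtime values from cparams. A hedged sketch with stand-in constants (the 1e6 full-attention base and the 12-layer count are illustrative, not taken from this diff):

// Sketch of the per-layer RoPE selection; constants stand in for hparams/cparams fields.
#include <cstdio>

int main() {
    const float freq_base_train_swa  = 10000.0f;   // hparams.rope_freq_base_train_swa (set above)
    const float freq_scale_train_swa = 1.0f;       // hparams.rope_freq_scale_train_swa (set above)
    const float freq_base            = 1000000.0f; // stand-in for cparams.rope_freq_base
    const float freq_scale           = 1.0f;       // stand-in for cparams.rope_freq_scale

    const int n_swa_pattern = 6;
    for (int il = 0; il < 12; ++il) {
        const bool is_swa = il % n_swa_pattern < (n_swa_pattern - 1);
        const float freq_base_l  = is_swa ? freq_base_train_swa  : freq_base;
        const float freq_scale_l = is_swa ? freq_scale_train_swa : freq_scale;
        printf("layer %2d: freq_base = %9.1f, freq_scale = %.2f\n", il, freq_base_l, freq_scale_l);
    }
    return 0;
}

The net effect recorded in the hunk is that sliding-window layers stay pinned to their training RoPE parameters, and only full-attention layers follow the per-context settings in cparams.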