
Commit 93cf1e4

Merge branch 'ggml-org:master' into master
2 parents 22d35ac + 84d5475 commit 93cf1e4

6 files changed: 37 additions (+), 43 deletions (-)


src/llama-context.cpp

Lines changed: 14 additions & 3 deletions
@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-    const auto & freq_base  = cparams.rope_freq_base;
-    const auto & freq_scale = cparams.rope_freq_scale;

     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+        float freq_base_l  = cparams.rope_freq_base;
+        float freq_scale_l = cparams.rope_freq_scale;
+
+        // TODO: improve
+        if (model.arch == LLM_ARCH_GEMMA3) {
+            const bool is_sliding = hparams.is_sliding(il);
+
+            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
+            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
+        }
+
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);

         ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);

-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);

         ggml_build_forward_expand(gf, cur);
     }
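The point of threading freq_base/freq_scale through build_rope_shift is that the K-cache shift can now use per-layer RoPE parameters: Gemma 3's sliding-window layers apply a fixed local base of 10000 and no frequency scaling, while its global layers keep the model-wide values. Below is a minimal standalone sketch of that selection; the struct and helper names are illustrative, not part of this commit.

```cpp
// Illustrative stand-ins; the real values come from cparams/hparams.
struct rope_params {
    float freq_base;
    float freq_scale;
};

// Mirrors the per-layer choice added to build_kv_self_shift: Gemma 3's
// sliding (SWA) layers use a local RoPE base of 10000 and scale 1.0,
// everything else keeps the model-wide base/scale.
static rope_params rope_params_for_layer(bool is_gemma3, bool is_sliding,
                                         float model_freq_base, float model_freq_scale) {
    if (is_gemma3 && is_sliding) {
        return { 10000.0f, 1.0f };
    }
    return { model_freq_base, model_freq_scale };
}
```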

src/llama-context.h

Lines changed: 2 additions & 0 deletions
@@ -168,6 +168,8 @@ struct llama_context {
             ggml_tensor * cur,
             ggml_tensor * shift,
             ggml_tensor * factors,
+                  float   freq_base,
+                  float   freq_scale,
     ggml_backend_buffer * bbuf) const;

     llm_graph_result_ptr build_kv_self_shift(

src/llama-graph.cpp

Lines changed: 1 addition & 28 deletions
@@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }

-    // TODO: improve
-    bool is_sliding = false;
-
-    switch (arch) {
-        case LLM_ARCH_COHERE2:
-            {
-                const int32_t sliding_window_pattern = 4;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                const int32_t sliding_window_pattern = 2;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_GEMMA3:
-            {
-                const int32_t sliding_window_pattern = 6;
-                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            } break;
-        case LLM_ARCH_PHI3:
-            {
-                is_sliding = hparams.n_swa > 0;
-            } break;
-        default:
-            {
-                is_sliding = false;
-            }
-    };
+    const bool is_sliding = hparams.is_sliding(il);

     const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();

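With the per-architecture switch gone, build_attn only needs the boolean from hparams.is_sliding(il) to choose between the sliding-window KQ mask and the full causal mask. A rough sketch of that dispatch with placeholder types (attn_masks and the raw float pointers are stand-ins for the real graph-input class):

```cpp
// Placeholder illustration of the mask selection that remains in build_attn:
// sliding-window layers attend through the SWA mask, global layers through
// the full causal mask.
struct attn_masks {
    const float * kq_mask_swa; // local (sliding-window) attention mask
    const float * kq_mask;     // full causal attention mask
};

static const float * select_kq_mask(const attn_masks & masks, bool is_sliding) {
    return is_sliding ? masks.kq_mask_swa : masks.kq_mask;
}
```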

src/llama-hparams.cpp

Lines changed: 8 additions & 0 deletions
@@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
 }
+
+bool llama_hparams::is_sliding(uint32_t il) const {
+    if (il < n_layer) {
+        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+    }
+
+    GGML_ABORT("fatal error");
+}
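As a worked example of the new predicate: with n_swa_pattern = 6 (Gemma 3), layers where il % 6 < 5 are sliding, so every sixth layer (il = 5, 11, 17, ...) is global; with the struct default n_swa_pattern = 1, the condition il % 1 < 0 is never true and no layer uses SWA. The loop below just prints that classification for a hypothetical 12-layer model (the n_swa and n_layer values are made up):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Same formula as llama_hparams::is_sliding, with illustrative values.
    const uint32_t n_swa         = 4096; // assumed non-zero sliding window
    const uint32_t n_swa_pattern = 6;    // Gemma 3: 5 sliding layers, then 1 global
    const uint32_t n_layer       = 12;   // hypothetical layer count

    for (uint32_t il = 0; il < n_layer; ++il) {
        const bool is_sliding = n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
        printf("layer %2u: %s\n", (unsigned) il, is_sliding ? "sliding (SWA)" : "global");
    }
    return 0;
}
```

Layers 5 and 11 come out as global and the rest as sliding, matching the "5-to-1 interleaved attention" comment removed from llm_build_gemma3 below.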

src/llama-hparams.h

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,7 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -133,6 +134,8 @@ struct llama_hparams {

     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
+
+    bool is_sliding(uint32_t il) const;
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

src/llama-model.cpp

Lines changed: 9 additions & 12 deletions
@@ -859,11 +859,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA2:
             {
                 hparams.n_swa = 4096; // default value of gemma 2
+                hparams.n_swa_pattern = 2;
+                hparams.attn_soft_cap = true;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
-                hparams.attn_soft_cap = true;

                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_2B; break;
@@ -874,6 +876,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
+                hparams.n_swa_pattern = 6;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

@@ -953,6 +957,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COHERE2:
             {
+                hparams.n_swa_pattern = 4;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
@@ -7413,12 +7419,8 @@ struct llm_build_gemma3 : public llm_graph_context {
         // TODO: is causal == true correct? might need some changes
         auto * inp_attn = build_attn_inp_kv_unified(true, true);

-        // "5-to-1 interleaved attention"
-        // 5 layers of local attention followed by 1 layer of global attention
-        static const int sliding_window_pattern = 6;
-
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);

             const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
             const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
@@ -8009,13 +8011,8 @@ struct llm_build_cohere2 : public llm_graph_context {

         auto * inp_attn = build_attn_inp_kv_unified(true, true);

-        // sliding window switch pattern
-        const int32_t sliding_window_pattern = 4;
-
         for (int il = 0; il < n_layer; ++il) {
-            // three layers sliding window attention (window size 4096) and ROPE
-            // fourth layer uses global attention without positional embeddings
-            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            const bool is_sliding = hparams.is_sliding(il);

             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
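With these defaults, each architecture's interleaving is captured by a single field: Gemma 2 alternates one sliding and one global layer (n_swa_pattern = 2), Cohere2 uses three sliding layers per global layer (= 4), and Gemma 3 five per global layer (= 6). The snippet below is a hypothetical sketch (not part of this commit) of how another architecture could describe a "3 sliding, 1 global" layout; the KQ-mask choice in build_attn and the K-shift RoPE parameters then follow from hparams.is_sliding():

```cpp
#include <cstdint>

// Hypothetical stand-in for the two hparams fields involved.
struct swa_hparams {
    uint32_t n_swa         = 0; // sliding window size; 0 disables SWA
    uint32_t n_swa_pattern = 1; // 1 => no layer is classified as sliding
};

// Illustrative only: a "3 sliding layers, then 1 global layer" layout with a
// 4096-token window, analogous to what load_hparams sets for Cohere2.
static swa_hparams make_three_to_one_swa() {
    swa_hparams hp;
    hp.n_swa         = 4096;
    hp.n_swa_pattern = 4;
    return hp;
}
```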
