
Commit be3bba6

Merge branch 'upstream' into concedo_experimental
# Conflicts:
#   src/llama-model.cpp

2 parents: 782e1e1 + c522ce4

File tree

6 files changed (+79, -78 lines)


src/llama-context.cpp

Lines changed: 5 additions & 9 deletions
@@ -537,16 +537,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
     const int64_t n_head_kv    = hparams.n_head_kv(il);
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

-    float freq_base_l  = cparams.rope_freq_base;
-    float freq_scale_l = cparams.rope_freq_scale;
+    const bool is_swa = hparams.is_swa(il);

-    // TODO: improve
-    if (model.arch == LLM_ARCH_GEMMA3) {
-        const bool is_sliding = hparams.is_sliding(il);
-
-        freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
-        freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
-    }
+    // note: the swa rope params could become part of the cparams in the future
+    // if we decide to make them configurable, like the non-sliding ones
+    const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+    const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

     ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
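With the new per-layer flag, the KV-shift graph no longer special-cases Gemma 3: sliding-window layers take the rope base/scale the model was trained with, and all other layers keep the user-configurable context parameters. A minimal standalone sketch of that selection (struct and function names are stand-ins, not llama.cpp API; the numeric values are examples only):

    #include <cstdio>

    // Reduced stand-ins for the fields this hunk reads from llama_hparams / llama_cparams.
    struct hparams_sketch { float rope_freq_base_train_swa; float rope_freq_scale_train_swa; };
    struct cparams_sketch { float rope_freq_base;           float rope_freq_scale;           };

    // Mirrors the new logic: SWA layers use the training-time SWA rope params,
    // every other layer keeps the (configurable) context params.
    static void pick_rope_params(bool is_swa, const hparams_sketch & hp, const cparams_sketch & cp,
                                 float & freq_base_l, float & freq_scale_l) {
        freq_base_l  = is_swa ? hp.rope_freq_base_train_swa  : cp.rope_freq_base;
        freq_scale_l = is_swa ? hp.rope_freq_scale_train_swa : cp.rope_freq_scale;
    }

    int main() {
        const hparams_sketch hp = { 10000.0f, 1.0f };    // example training-time SWA rope params
        const cparams_sketch cp = { 1000000.0f, 1.0f };  // example user/context rope params

        float base = 0.0f, scale = 0.0f;
        pick_rope_params(true, hp, cp, base, scale);
        printf("swa layer : base=%.0f scale=%.1f\n", base, scale);
        pick_rope_params(false, hp, cp, base, scale);
        printf("full layer: base=%.0f scale=%.1f\n", base, scale);
        return 0;
    }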

src/llama-graph.cpp

Lines changed: 6 additions & 12 deletions
@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
-        bool causal,
-        bool swa) const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);

     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);

     const auto n_kv = kv_self->n;

-    inp->self_kq_mask = causal
-        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

-    if (swa) {
+    if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);

-        inp->self_kq_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);

@@ -1403,9 +1397,9 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }

-    const bool is_sliding = hparams.is_sliding(il);
+    const bool is_swa = hparams.is_swa(il);

-    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

     const auto n_kv = kv_self->n;

src/llama-graph.h

Lines changed: 1 addition & 3 deletions
@@ -509,9 +509,7 @@ struct llm_graph_context {
             float     kq_scale,
             int       il) const;

-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
-            bool causal,
-            bool swa) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;

     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,

src/llama-hparams.cpp

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }

-bool llama_hparams::is_sliding(uint32_t il) const {
+bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
     }
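The rename from is_sliding() to is_swa() matches the new field names, but the layer pattern itself is unchanged: within every group of n_swa_pattern layers, all but the last use sliding-window attention. A runnable sketch of that predicate with example values (12 layers, a hypothetical 6-layer pattern, i.e. 5 sliding layers followed by 1 full-attention layer):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the body of llama_hparams::is_swa(); out-of-range layer indices are
    // simply treated as non-SWA here to keep the sketch self-contained.
    static bool is_swa(uint32_t il, uint32_t n_layer, uint32_t n_swa, uint32_t n_swa_pattern) {
        if (il < n_layer) {
            return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
        }
        return false;
    }

    int main() {
        const uint32_t n_layer = 12, n_swa = 1024, n_swa_pattern = 6; // example values only

        for (uint32_t il = 0; il < n_layer; ++il) {
            printf("layer %2u: %s\n", il, is_swa(il, n_layer, n_swa, n_swa_pattern) ? "swa" : "full");
        }
        return 0;
    }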

src/llama-hparams.h

Lines changed: 3 additions & 1 deletion
@@ -79,7 +79,9 @@ struct llama_hparams {

     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
+    float rope_freq_base_train_swa;
     float rope_freq_scale_train;
+    float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul;

@@ -135,7 +137,7 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;

-    bool is_sliding(uint32_t il) const;
+    bool is_swa(uint32_t il) const;
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
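The two new members are plain floats, so the trivially-copyable check at the end of the header still holds. A minimal illustration with a reduced stand-in struct (hypothetical name, only the rope fields touched by this commit):

    #include <type_traits>

    // Reduced stand-in for llama_hparams, limited to the rope fields in this diff.
    struct hparams_rope_sketch {
        float rope_freq_base_train;
        float rope_freq_base_train_swa;
        float rope_freq_scale_train;
        float rope_freq_scale_train_swa;
    };

    // Same style of check as the real header: adding POD float members keeps the
    // struct trivially copyable.
    static_assert(std::is_trivially_copyable<hparams_rope_sketch>::value,
                  "hparams_rope_sketch must be trivially copyable");

    int main() { return 0; }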
