
Commit 850d301

Merge branch 'ggml-org:master' into master
2 parents: 93cf1e4 + 8fcb563

File tree: 11 files changed (+104 lines, -81 lines)

common/common.cpp

Lines changed: 3 additions & 0 deletions
@@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) {
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
+        llama_set_warmup(lctx, true);
+
         std::vector<llama_token> tmp;
         llama_token bos = llama_vocab_bos(vocab);
         llama_token eos = llama_vocab_eos(vocab);
@@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_kv_self_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
+        llama_set_warmup(lctx, false);
     }
 
     iparams.model.reset(model);

examples/server/utils.hpp

Lines changed: 3 additions & 1 deletion
@@ -621,7 +621,9 @@ static json oaicompat_completion_params_parse(
 
     llama_params["chat_format"] = static_cast<int>(chat_params.format);
     llama_params["prompt"] = chat_params.prompt;
-    llama_params["grammar"] = chat_params.grammar;
+    if (!chat_params.grammar.empty()) {
+        llama_params["grammar"] = chat_params.grammar;
+    }
     llama_params["grammar_lazy"] = chat_params.grammar_lazy;
     auto grammar_triggers = json::array();
     for (const auto & trigger : chat_params.grammar_triggers) {
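
For context, here is a minimal sketch of the behavioural change, written against nlohmann::json with a hypothetical helper name: when the chat template yields no grammar, the "grammar" key is now omitted from llama_params entirely instead of being set to an empty string.

    #include <nlohmann/json.hpp>
    #include <string>

    using json = nlohmann::json;

    // Hypothetical illustration of the conditional above: the key exists only
    // when a non-empty grammar was produced; otherwise it is absent, not "".
    json sketch_grammar_param(const std::string & grammar) {
        json llama_params;
        if (!grammar.empty()) {
            llama_params["grammar"] = grammar;
        }
        return llama_params;
    }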

include/llama.h

Lines changed: 4 additions & 0 deletions
@@ -945,6 +945,10 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
+    // Set whether the model is in warmup mode or not
+    // If true, all model tensors are activated during llama_decode() to load and cache their weights.
+    LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
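
For reference, a minimal caller-side sketch of the new API (a sketch only, assuming ctx and vocab came from the usual llama_init_from_model / llama_model_get_vocab setup and that the vocab defines a BOS token): enable warmup, decode a throwaway token so every tensor is activated and its weights are loaded and cached, then clear the KV cache and switch warmup off before real inference.

    #include <vector>
    #include "llama.h"

    // Minimal warmup sketch: touch all model weights once, then restore normal decoding.
    static void warmup(llama_context * ctx, const llama_vocab * vocab) {
        std::vector<llama_token> tmp = { llama_vocab_bos(vocab) };

        llama_set_warmup(ctx, true);   // activate all tensors during llama_decode()
        llama_decode(ctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size()));
        llama_kv_self_clear(ctx);      // drop the throwaway tokens from the KV cache
        llama_synchronize(ctx);
        llama_set_warmup(ctx, false);  // back to normal inference
    }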

src/llama-context.cpp

Lines changed: 17 additions & 10 deletions
@@ -39,6 +39,7 @@ llama_context::llama_context(
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
+    cparams.warmup = false;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -537,16 +538,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
     const int64_t n_head_kv = hparams.n_head_kv(il);
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-    float freq_base_l  = cparams.rope_freq_base;
-    float freq_scale_l = cparams.rope_freq_scale;
+    const bool is_swa = hparams.is_swa(il);
 
-    // TODO: improve
-    if (model.arch == LLM_ARCH_GEMMA3) {
-        const bool is_sliding = hparams.is_sliding(il);
-
-        freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
-        freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
-    }
+    // note: the swa rope params could become part of the cparams in the future
+    //       if we decide to make them configurable, like the non-sliding ones
+    const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+    const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
     ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
@@ -952,6 +949,12 @@ void llama_context::set_causal_attn(bool value) {
     cparams.causal_attn = value;
 }
 
+void llama_context::set_warmup(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.warmup = value;
+}
+
 void llama_context::set_adapter_lora(
         llama_adapter_lora * adapter,
         float scale) {
@@ -1598,7 +1601,7 @@ void llama_context::output_reorder() {
 //
 
 int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(8192, 5*model.n_tensors());
+    return std::max<int32_t>(65536, 5*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_init() {
@@ -2376,6 +2379,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
     ctx->set_causal_attn(causal_attn);
 }
 
+void llama_set_warmup(llama_context * ctx, bool warmup) {
+    ctx->set_warmup(warmup);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }

src/llama-context.h

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ struct llama_context {
 
     void set_embeddings (bool value);
    void set_causal_attn(bool value);
+    void set_warmup(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool warmup;
 
     enum llama_pooling_type pooling_type;

src/llama-graph.cpp

Lines changed: 7 additions & 13 deletions
@@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_embd_head_v (hparams.n_embd_head_v),
     n_embd_v_gqa  (hparams.n_embd_v_gqa()),
     n_expert      (hparams.n_expert),
-    n_expert_used (hparams.n_expert_used),
+    n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
     freq_base     (cparams.rope_freq_base),
     freq_scale    (cparams.rope_freq_scale),
     ext_factor    (cparams.yarn_ext_factor),
@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
-        bool causal,
-        bool swa) const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
-    inp->self_kq_mask = causal
-        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    if (swa) {
+    if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
-        inp->self_kq_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
@@ -1403,9 +1397,9 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    const bool is_sliding = hparams.is_sliding(il);
+    const bool is_swa = hparams.is_swa(il);
 
-    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     const auto n_kv = kv_self->n;
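
The n_expert_used change above is what makes the warmup flag effective for mixture-of-experts models: normally only a few experts are selected per token, so a single warmup decode would leave most expert weights untouched (and, when the model is mmapped, not yet paged in). Forcing the count up to n_expert during warmup routes the warmup tokens through every expert. A rough illustrative sketch of the idea follows; the names are hypothetical and this is not the llama.cpp graph code itself.

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Hypothetical top-k expert selection for one token. During warmup we pretend
    // k == n_expert so every expert's FFN weights are used at least once.
    static std::vector<int> pick_experts(const std::vector<float> & router_logits,
                                         int32_t n_expert_used, bool warmup) {
        const int32_t n_expert = (int32_t) router_logits.size();
        const int32_t k = warmup ? n_expert : n_expert_used; // mirrors the cparams.warmup override

        std::vector<int> idx(n_expert);
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return router_logits[a] > router_logits[b]; });
        idx.resize(k);
        return idx; // experts whose weights get touched for this token
    }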

src/llama-graph.h

Lines changed: 1 addition & 3 deletions
@@ -509,9 +509,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
-            bool causal,
-            bool swa) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,

src/llama-hparams.cpp

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
-bool llama_hparams::is_sliding(uint32_t il) const {
+bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
     }
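
As a concrete illustration of the renamed predicate, the standalone sketch below evaluates the same formula for a hypothetical 12-layer model with n_swa_pattern = 6 (the values are invented for the example): layers 0-4 and 6-10 come out as sliding-window layers, while layers 5 and 11 use full attention.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical hparams, for illustration only
        const uint32_t n_layer       = 12;
        const uint32_t n_swa         = 1024; // any non-zero window size
        const uint32_t n_swa_pattern = 6;    // 5 SWA layers followed by 1 full-attention layer

        for (uint32_t il = 0; il < n_layer; ++il) {
            const bool is_swa = n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
            std::printf("layer %2u: %s\n", il, is_swa ? "sliding-window" : "full attention");
        }
        return 0;
    }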

src/llama-hparams.h

Lines changed: 3 additions & 1 deletion
@@ -79,7 +79,9 @@ struct llama_hparams {
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
+    float rope_freq_base_train_swa;
     float rope_freq_scale_train;
+    float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul;
 
@@ -135,7 +137,7 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
 
-    bool is_sliding(uint32_t il) const;
+    bool is_swa(uint32_t il) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
