Skip to content

Commit 85a7d86

Browse files
authored
memory : remove KV cache size padding (#16812)
* memory : remove KV cache size padding * cont : restore padding for n_kv tensor shape * server : use slot context size instead of training context size * server : simplify context limit logic
1 parent a8ca18b commit 85a7d86

File tree

6 files changed

+14
-54
lines changed

6 files changed

+14
-54
lines changed

src/llama-kv-cache.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -961,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const {
961961
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
962962
uint32_t result = 0;
963963

964+
// pad the n_kv value so that the graph remains constant across batches and can be reused
965+
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
966+
const uint32_t n_pad_cur = std::max(n_pad, 256u);
967+
964968
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
965969
const auto & cells = v_cells[sinfo.strm[s]];
966970

967-
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
971+
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
968972
}
969973

970974
return result;
@@ -2014,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
20142018
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
20152019
kv->set_input_pos_bucket(dst, ubatch);
20162020
}
2017-
2018-
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
2019-
// the FA kernels require padding to avoid extra runtime boundary checks
2020-
return cparams.flash_attn ? 256u : 32u;
2021-
}

src/llama-kv-cache.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ struct llama_context;
1919

2020
class llama_kv_cache : public llama_memory_i {
2121
public:
22-
static uint32_t get_padding(const llama_cparams & cparams);
23-
2422
struct stream_copy_info {
2523
bool empty() const {
2624
assert(ssrc.size() == sdst.size());

src/llama-model.cpp

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
1964119641
}
1964219642
};
1964319643

19644-
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
19644+
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
1964519645
llama_memory_i * res;
1964619646

1964719647
switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1969219692
};
1969319693
}
1969419694

19695-
const auto padding = llama_kv_cache::get_padding(cparams);
19696-
19697-
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
19698-
1969919695
res = new llama_memory_hybrid(
1970019696
/* model */ *this,
1970119697
/* attn_type_k */ params.type_k,
1970219698
/* attn_type_v */ params.type_v,
1970319699
/* attn_v_trans */ !cparams.flash_attn,
1970419700
/* attn_kv_size */ cparams.n_ctx,
19705-
/* attn_n_pad */ padding,
19701+
/* attn_n_pad */ 1,
1970619702
/* attn_n_swa */ hparams.n_swa,
1970719703
/* attn_swa_type */ hparams.swa_type,
1970819704
/* recurrent_type_k */ GGML_TYPE_F32,
@@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1971419710
/* filter_attn */ std::move(filter_attn),
1971519711
/* filter_recr */ std::move(filter_recr));
1971619712
} else {
19717-
const auto padding = llama_kv_cache::get_padding(cparams);
19718-
1971919713
uint32_t n_ctx_per_stream = cparams.n_ctx;
1972019714

1972119715
if (!cparams.kv_unified) {
1972219716
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
19723-
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
19724-
19725-
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
19726-
} else {
19727-
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
19728-
19729-
cparams.n_ctx = n_ctx_per_stream;
1973019717
}
1973119718

19732-
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
19733-
1973419719
llama_memory_i::layer_reuse_cb reuse = nullptr;
1973519720

1973619721
if (arch == LLM_ARCH_GEMMA3N) {
@@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1975719742
n_ctx_per_stream,
1975819743
cparams.n_seq_max,
1975919744
cparams.n_ubatch,
19760-
padding,
19745+
1,
1976119746
nullptr,
1976219747
reuse);
1976319748
} else {
@@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1977219757
cparams.kv_unified,
1977319758
n_ctx_per_stream,
1977419759
cparams.n_seq_max,
19775-
padding,
19760+
1,
1977619761
hparams.n_swa,
1977719762
hparams.swa_type,
1977819763
nullptr,

src/llama-model.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,9 +500,8 @@ struct llama_model {
500500

501501
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
502502

503-
// note: can mutate `cparams`
504503
// TODO: move this to new llm_arch_model_i interface
505-
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
504+
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
506505

507506
// TODO: move this to new llm_arch_model_i interface
508507
ggml_cgraph * build_graph(const llm_graph_params & params) const;

tools/server/server.cpp

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2866,10 +2866,12 @@ struct server_context {
28662866

28672867
// if context shifting is disabled, make sure that we don't run out of context
28682868
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
2869+
slot.truncated = true;
28692870
slot.stop = STOP_TYPE_LIMIT;
28702871
slot.has_next_token = false;
28712872

2872-
SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
2873+
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
2874+
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
28732875
}
28742876

28752877
// check the limits
@@ -2929,36 +2931,13 @@ struct server_context {
29292931
}
29302932
}
29312933

2932-
// if context shift is disabled, we stop when it reaches the context limit
2933-
if (slot.n_past >= slot.n_ctx) {
2934-
slot.truncated = true;
2935-
slot.stop = STOP_TYPE_LIMIT;
2936-
slot.has_next_token = false;
2937-
2938-
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
2939-
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
2940-
}
2941-
29422934
if (llama_vocab_is_eog(vocab, result.tok)) {
29432935
slot.stop = STOP_TYPE_EOS;
29442936
slot.has_next_token = false;
29452937

29462938
SLT_DBG(slot, "%s", "stopped by EOS\n");
29472939
}
29482940

2949-
const auto n_ctx_train = llama_model_n_ctx_train(model);
2950-
2951-
if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
2952-
slot.truncated = true;
2953-
slot.stop = STOP_TYPE_LIMIT;
2954-
slot.has_next_token = false; // stop prediction
2955-
2956-
SLT_WRN(slot,
2957-
"n_predict (%d) is set for infinite generation. "
2958-
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
2959-
slot.task->params.n_predict, n_ctx_train);
2960-
}
2961-
29622941
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
29632942

29642943
return slot.has_next_token; // continue

tools/server/tests/unit/test_ctx_shift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():
4545

4646
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
4747
(64, 64, False),
48-
(-1, 120, True),
48+
(-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
4949
])
5050
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
5151
global server

0 commit comments

Comments
 (0)