11 changes: 5 additions & 6 deletions src/llama-kv-cache.cpp
@@ -957,10 +957,14 @@ bool llama_kv_cache::get_has_shift() const {
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;

// pad the n_kv value so that the graph remains constant across batches and can be reused
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
const uint32_t n_pad_cur = std::max(n_pad, 256u);

for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
const auto & cells = v_cells[sinfo.strm[s]];

result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
}

return result;
@@ -2010,8 +2014,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}

uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}
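
For reference, a minimal sketch of the padding arithmetic now applied inside `get_n_kv`, assuming `GGML_PAD(x, n)` rounds `x` up to a multiple of `n`; the `pad_up` helper below is an illustrative stand-in rather than the actual ggml macro:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// illustrative stand-in for GGML_PAD: round x up to a multiple of n (n assumed a power of two)
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1);
}

// mirrors the per-stream computation after this change: pad n_kv to at least 256
// so the graph shape stays constant across batches, but never exceed the stream size
static uint32_t n_kv_for_stream(uint32_t cells_size, uint32_t used_max_p1, uint32_t n_pad) {
    const uint32_t n_pad_cur = std::max(n_pad, 256u);
    return std::min(cells_size, std::max(n_pad_cur, pad_up(used_max_p1, n_pad_cur)));
}

int main() {
    // 300 used cells in a 4096-cell stream with n_pad = 32 -> padded up to 512
    printf("%u\n", n_kv_for_stream(4096, 300, 32)); // 512
    // an almost-empty stream still reserves the 256-cell minimum
    printf("%u\n", n_kv_for_stream(4096, 10, 32));  // 256
}
```

The 256 floor matches the flash-attention value returned by the `get_padding()` helper that this PR removes in the deleted lines above.
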
2 changes: 0 additions & 2 deletions src/llama-kv-cache.h
@@ -19,8 +19,6 @@ struct llama_context;

class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);

struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
23 changes: 4 additions & 19 deletions src/llama-model.cpp
@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
}
};

llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
llama_memory_i * res;

switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
};
}

const auto padding = llama_kv_cache::get_padding(cparams);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
@@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);

uint32_t n_ctx_per_stream = cparams.n_ctx;

if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
} else {
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

cparams.n_ctx = n_ctx_per_stream;
}

LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

llama_memory_i::layer_reuse_cb reuse = nullptr;

if (arch == LLM_ARCH_GEMMA3N) {
@@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding,
1,
nullptr,
reuse);
} else {
@@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max,
padding,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
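
For context, a hedged sketch of the per-stream context split that remains in `create_memory` now that the padding step is gone; the function name is illustrative and only the ceiling division mirrors the diff:

```cpp
#include <cstdint>
#include <cstdio>

// with a non-unified KV cache each sequence gets its own stream of roughly
// n_ctx / n_seq_max cells; the division rounds up so the remainder is not lost
static uint32_t ctx_per_stream(uint32_t n_ctx, uint32_t n_seq_max, bool kv_unified) {
    if (kv_unified) {
        return n_ctx; // all sequences share a single stream
    }
    return (n_ctx + n_seq_max - 1)/n_seq_max;
}

int main() {
    printf("%u\n", ctx_per_stream(8192, 4, false)); // 2048 cells per stream
    printf("%u\n", ctx_per_stream(8190, 4, false)); // 2048 (rounded up)
    printf("%u\n", ctx_per_stream(8192, 4, true));  // 8192 (single shared stream)
}
```

Because `cparams.n_ctx` is no longer rounded up here, `cparams` can be passed as `const`, and the KV cache handles its own alignment through the `n_kv` floor shown earlier (the constructors now receive a padding of 1).
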
3 changes: 1 addition & 2 deletions src/llama-model.h
@@ -500,9 +500,8 @@ struct llama_model {

ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;

// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;

// TODO: move this to new llm_arch_model_i interface
ggml_cgraph * build_graph(const llm_graph_params & params) const;
8 changes: 3 additions & 5 deletions tools/server/server.cpp
@@ -2946,17 +2946,15 @@ struct server_context {
SLT_DBG(slot, "%s", "stopped by EOS\n");
}

const auto n_ctx_train = llama_model_n_ctx_train(model);

if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= slot.n_ctx) {
Member Author:

The old logic of using the training context as a generation limit seems dubious - the context size of the slot should impose the limit.

After #16736, slot.n_ctx will be capped to the training context either way, so this change should not make a big difference in practice.

@ngxson (Collaborator), Oct 28, 2025:

In this case, do you think this code branch can be removed altogether? The n_ctx cap is already imposed by the slot.n_past >= slot.n_ctx condition above (L2933).

Member Author:

Yes, good point.

slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction

SLT_WRN(slot,
"n_predict (%d) is set for infinite generation. "
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
slot.task->params.n_predict, n_ctx_train);
"Limiting generated tokens to slot.n_ctx (%d) to avoid EOS-less generation infinite loop\n",
slot.task->params.n_predict, slot.n_ctx);
}

SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
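
A minimal sketch of the stop condition discussed in the review thread above, with illustrative names and simplified types:

```cpp
// with n_predict < 1 (infinite generation), the limit is now the slot's own
// context size rather than the model's training context, so a model that never
// emits EOS still stops once the slot context is full
static bool hit_generation_limit(int n_predict, int n_prompt_tokens, int n_decoded, int n_ctx_slot) {
    return n_predict < 1 && n_prompt_tokens + n_decoded >= n_ctx_slot;
}
```

As noted in the thread, the earlier slot.n_past >= slot.n_ctx check may already cover this case, so the branch could potentially be dropped altogether.
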
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_ctx_shift.py
@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():

@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 120, True),
(-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server