5 changes: 4 additions & 1 deletion include/llama.h
@@ -461,7 +461,10 @@ extern "C" {
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API bool llama_supports_rpc (void);

// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@@ -585,7 +588,7 @@ extern "C" {
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);

// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
// NOTE: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);

// Get the invocation tokens if the current lora is an alora
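For context (not part of the diff): a minimal sketch of how a caller might use the new getter together with the existing ones to read back the effective values after context creation. The model path and parameter values below are placeholders.

```cpp
// Minimal sketch (assumed usage, not from this PR): read back the effective
// context parameters, which may differ from the requested ones.
#include "llama.h"
#include <cstdio>

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx     = 8192; // requested total context
    cparams.n_seq_max = 3;    // intentionally not a divisor of n_ctx

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        llama_model_free(model);
        return 1;
    }

    // the actual values may have been adjusted (e.g. n_ctx rounded down)
    printf("n_ctx     = %u\n", llama_n_ctx(ctx));
    printf("n_ctx_seq = %u\n", llama_n_ctx_seq(ctx));
    printf("n_seq_max = %u\n", llama_n_seq_max(ctx));

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```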
37 changes: 27 additions & 10 deletions src/llama-context.cpp
@@ -112,11 +112,24 @@ llama_context::llama_context(
}
}

const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.kv_unified) {
cparams.n_ctx_seq = cparams.n_ctx;
} else {
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;

if (cparams.n_ctx_seq == 0) {
throw std::runtime_error("n_ctx_seq == 0");
}

if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
}
}

LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +138,14 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}

if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}

if (!hparams.vocab_only) {
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}

uint32_t llama_context::n_ctx_per_seq() const {
return cparams.n_ctx / cparams.n_seq_max;
uint32_t llama_context::n_ctx_seq() const {
return cparams.n_ctx_seq;
}

uint32_t llama_context::n_batch() const {
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}

uint32_t llama_n_ctx_seq(const llama_context * ctx) {
return ctx->n_ctx_seq();
}

uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}
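To illustrate the rounding rule above with made-up numbers: with a split KV cache, n_ctx = 8192 and n_seq_max = 3, the per-sequence context becomes 2730 and the total context is rounded down to 8190. A standalone sketch of the same arithmetic (not code from the PR):

```cpp
// Standalone illustration of the split-KV-cache rounding performed in the
// llama_context constructor above (values are arbitrary examples).
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t n_ctx     = 8192;
    uint32_t n_seq_max = 3;

    uint32_t n_ctx_seq = n_ctx / n_seq_max;    // integer division: 8192 / 3 = 2730
    if (n_ctx != n_ctx_seq * n_seq_max) {
        n_ctx = n_ctx_seq * n_seq_max;         // rounded down: 2730 * 3 = 8190
        printf("n_ctx is not divisible by n_seq_max - rounding down to %u\n", n_ctx);
    }

    printf("n_ctx = %u, n_ctx_seq = %u, n_seq_max = %u\n", n_ctx, n_ctx_seq, n_seq_max);
    return 0;
}
```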
10 changes: 5 additions & 5 deletions src/llama-context.h
@@ -43,11 +43,11 @@ struct llama_context {

ggml_backend_sched_t get_sched() const;

uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_ctx() const;
uint32_t n_ctx_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;

uint32_t n_threads() const;
uint32_t n_threads_batch() const;
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -8,6 +8,7 @@

struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_ctx_seq; // context for a single sequence
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
14 changes: 4 additions & 10 deletions src/llama-model.cpp
@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
}

ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
const uint32_t n_ctx_seq = cparams.n_ctx_seq;

// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}

if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}

@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
uint32_t n_ctx_per_stream = cparams.n_ctx;

if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
}

llama_memory_i::layer_reuse_cb reuse = nullptr;

if (arch == LLM_ARCH_GEMMA3N) {
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
1,
hparams.n_swa,
9 changes: 8 additions & 1 deletion tests/test-thread-safety.cpp
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
}

batch = llama_batch_get_one(&token, 1);
if (llama_decode(ctx.get(), batch)) {

int ret = llama_decode(ctx.get(), batch);
if (ret == 1 && i > 0) {
LOG_INF("Context full, stopping generation.\n");
break;
}

if (ret != 0) {
LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
failed.store(true);
return;
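The updated test treats a return value of 1 from llama_decode as "context full" rather than a failure. A rough standalone sketch of the same pattern (the helper name and logging are placeholders, not part of the test):

```cpp
// Sketch of the return-value handling used by the updated test:
//   0 -> success, 1 -> no KV slot for the batch (context full), other -> failure.
#include "llama.h"
#include <cstdio>

static bool decode_or_stop(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;   // decoded successfully, caller may continue generating
    }
    if (ret == 1) {
        // context is full - stop generation gracefully instead of reporting an error
        fprintf(stderr, "context full, stopping generation\n");
        return false;
    }
    fprintf(stderr, "llama_decode failed, ret = %d\n", ret);
    return false;
}
```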
86 changes: 72 additions & 14 deletions tools/server/server.cpp
@@ -2407,7 +2407,7 @@ struct server_context {

params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2495,10 +2495,16 @@ struct server_context {
}

void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;

SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

const int n_ctx_train = llama_model_n_ctx_train(model);

int n_ctx_slot = llama_n_ctx_seq(ctx);
if (n_ctx_slot > n_ctx_train) {
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
n_ctx_slot = n_ctx_train;
}

for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;

@@ -2527,7 +2533,7 @@ struct server_context {
}
}

SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
return ret;
}

// return true if at least one slot has been purged
// TODO: improve logic
// - smarter decision which slot to purge (LRU or longest prompt?)
// - move slot to level 2 cache instead of removing?
// - instead of purging, try to store and resume later?
bool try_purge_idle_slots() {
bool res = false;

if (!params_base.kv_unified) {
return res;
}

for (auto & slot : slots) {
if (slot.is_processing()) {
continue;
}

if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());

llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();

res = true;

// purge slots one by one
break;
}
}

return res;
}

bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();

@@ -3635,9 +3674,10 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// next, batch any pending prompts without exceeding n_batch
float alora_scale = -1.0f;
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;

// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {

// truncate any tokens that are beyond n_past for this slot
const llama_pos p0 = slot.prompt.tokens.pos_next();

SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);

if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

// there is no common part left
Expand All @@ -3924,8 +3967,6 @@ struct server_context {
slot.prompt.tokens.clear();
}

SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);

// check if we should process the image
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
@@ -4126,6 +4167,8 @@ struct server_context {
std::string err;

if (n_batch == 1 && ret == 1) {
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
// need to remove the tokens from the current batch too
err = "Context size has been exceeded.";
}

@@ -4141,17 +4184,23 @@ struct server_context {
// TODO: handle ret == 2 (abort) when we start aborting

if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);

for (auto & slot : slots) {
send_error(slot, err);
slot.release();
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
}
}

break;
}
}

// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
if (!try_purge_idle_slots()) {
n_batch /= 2;
}

SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
Review discussion (on the SRV_WRN warning above):

Collaborator: this warning should be moved inside the if condition above, right?

Collaborator: Also, maybe I forgot this from a discussion before, but currently in which case do we need to retry with a smaller batch size?

ggerganov (Member, Author), Nov 1, 2025: The main case for retrying with smaller batches was back when we didn't have ggml_set_rows and we always had to search for a contiguous set of cells (KV slots) inside the cache buffer to place the input batch. Now with ggml_set_rows this is no longer needed and, technically, retrying with a smaller batch size has almost no purpose except in some rare cases.

But generally, when llama_decode returns 1, you should retry with a smaller batch.
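To make the retry pattern from this discussion concrete, here is a simplified standalone sketch (not the server's actual loop; the function and variable names are illustrative only) of halving the batch size when llama_decode returns 1:

```cpp
// Simplified sketch of the retry-with-smaller-batch pattern discussed above.
#include "llama.h"
#include <algorithm>
#include <cstdio>

static bool decode_with_retry(llama_context * ctx, llama_token * tokens, int32_t n_tokens, int32_t n_batch) {
    for (int32_t i = 0; i < n_tokens; ) {
        const int32_t n_eval = std::min(n_batch, n_tokens - i);

        const int32_t ret = llama_decode(ctx, llama_batch_get_one(tokens + i, n_eval));
        if (ret == 0) {
            i += n_eval;       // this chunk was decoded - advance
            continue;
        }
        if (ret == 1 && n_batch > 1) {
            n_batch /= 2;      // no free KV slot - retry the same tokens with half the batch
            fprintf(stderr, "retrying with n_batch = %d\n", n_batch);
            continue;
        }
        return false;          // hard error, or the batch cannot be reduced further
    }
    return true;
}
```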


@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
return 1;
}

// TODO: should we have a separate n_parallel parameter for the server?
// https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
if (params.n_parallel == 1 && params.kv_unified == false) {
LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);

params.n_parallel = 4;
params.kv_unified = true;
}
Review discussion (on lines +4445 to +4450):

Collaborator: Is there a reason why this can't be default params in arg.h?

Member (Author): I'll see if I can make it the default - I thought that some of the examples might not like it.

Collaborator: Hmm yeah, I didn't notice that there are multiple examples all using n_parallel. In this case, maybe we can use a dedicated variable for the server, like params.n_parallel_server? This can be useful when auto-generating the documentation for server args.


common_init();

// struct that contains llama context and inference
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto n_prompt_tokens = inputs[i].size();
8 changes: 4 additions & 4 deletions tools/server/tests/unit/test_chat_completion.py
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
(64, 15, False),
(64, 3, False),
(64, 1, True),
]
)
def test_return_progresssss(n_batch, batch_count, reuse_cache):
def test_return_progress(n_batch, batch_count, reuse_cache):
global server
server.n_batch = n_batch
server.n_ctx = 2048
server.n_ctx = 256
server.n_slots = 1
server.start()
def make_cmpl_request():
return server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": 10,
"messages": [
{"role": "user", "content": "This is a test" * 100},
{"role": "user", "content": "This is a test" * 10},
],
"stream": True,
"return_progress": True,