Skip to content

Commit 2ca720c

Browse files
committed
llama : update per-seq context computation
1 parent f3d1607 commit 2ca720c

File tree

8 files changed

+40
-36
lines changed

8 files changed

+40
-36
lines changed

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ extern "C" {
462462
LLAMA_API bool llama_supports_rpc (void);
463463

464464
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
465+
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
465466
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
466467
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
467468
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);

src/llama-context.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,17 @@ llama_context::llama_context(
112112
}
113113
}
114114

115+
cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
116+
117+
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
118+
LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
119+
120+
cparams.n_ctx_seq = hparams.n_ctx_train;
121+
}
122+
115123
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
116124
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
117-
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq());
125+
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
118126
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
119127
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
120128
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -123,14 +131,14 @@ llama_context::llama_context(
123131
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
124132
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
125133

126-
if (n_ctx_per_seq() < hparams.n_ctx_train) {
127-
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
128-
__func__, n_ctx_per_seq(), hparams.n_ctx_train);
134+
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
135+
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
136+
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
129137
}
130138

131-
if (n_ctx_per_seq() > hparams.n_ctx_train) {
132-
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
133-
__func__, n_ctx_per_seq(), hparams.n_ctx_train);
139+
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
140+
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
141+
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
134142
}
135143

136144
if (!hparams.vocab_only) {
@@ -451,8 +459,8 @@ uint32_t llama_context::n_ctx() const {
451459
return cparams.n_ctx;
452460
}
453461

454-
uint32_t llama_context::n_ctx_per_seq() const {
455-
return cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
462+
uint32_t llama_context::n_ctx_seq() const {
463+
return cparams.n_ctx_seq;
456464
}
457465

458466
uint32_t llama_context::n_batch() const {
@@ -2381,6 +2389,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
23812389
return ctx->n_ctx();
23822390
}
23832391

2392+
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
2393+
return ctx->n_ctx_seq();
2394+
}
2395+
23842396
uint32_t llama_n_batch(const llama_context * ctx) {
23852397
return ctx->n_batch();
23862398
}

src/llama-context.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ struct llama_context {
4343

4444
ggml_backend_sched_t get_sched() const;
4545

46-
uint32_t n_ctx() const;
47-
uint32_t n_ctx_per_seq() const;
48-
uint32_t n_batch() const;
49-
uint32_t n_ubatch() const;
50-
uint32_t n_seq_max() const;
46+
uint32_t n_ctx() const;
47+
uint32_t n_ctx_seq() const;
48+
uint32_t n_batch() const;
49+
uint32_t n_ubatch() const;
50+
uint32_t n_seq_max() const;
5151

5252
uint32_t n_threads() const;
5353
uint32_t n_threads_batch() const;

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
struct llama_cparams {
1010
uint32_t n_ctx; // context size used during inference
11+
uint32_t n_ctx_seq; // context for a single sequence
1112
uint32_t n_batch;
1213
uint32_t n_ubatch;
1314
uint32_t n_seq_max;

src/llama-model.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6668,14 +6668,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
66686668
}
66696669

66706670
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
6671-
const uint32_t n_ctx_per_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
6671+
const uint32_t n_ctx_seq = cparams.n_ctx_seq;
66726672

66736673
// choose long/short freq factors based on the context size
66746674
if (layers[il].rope_freqs != nullptr) {
66756675
return layers[il].rope_freqs;
66766676
}
66776677

6678-
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
6678+
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
66796679
return layers[il].rope_long;
66806680
}
66816681

@@ -20190,12 +20190,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
2019020190
/* filter_attn */ std::move(filter_attn),
2019120191
/* filter_recr */ std::move(filter_recr));
2019220192
} else {
20193-
uint32_t n_ctx_per_stream = cparams.n_ctx;
20194-
20195-
if (!cparams.kv_unified) {
20196-
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
20197-
}
20198-
2019920193
llama_memory_i::layer_reuse_cb reuse = nullptr;
2020020194

2020120195
if (arch == LLM_ARCH_GEMMA3N) {
@@ -20219,7 +20213,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
2021920213
cparams.offload_kqv,
2022020214
params.swa_full,
2022120215
cparams.kv_unified,
20222-
n_ctx_per_stream,
20216+
cparams.n_ctx_seq,
2022320217
cparams.n_seq_max,
2022420218
cparams.n_ubatch,
2022520219
1,
@@ -20235,7 +20229,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
2023520229
!cparams.flash_attn,
2023620230
cparams.offload_kqv,
2023720231
cparams.kv_unified,
20238-
n_ctx_per_stream,
20232+
cparams.n_ctx_seq,
2023920233
cparams.n_seq_max,
2024020234
1,
2024120235
hparams.n_swa,

tools/server/server.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,10 +2379,6 @@ struct server_context {
23792379
llama_batch_free(batch);
23802380
}
23812381

2382-
int32_t n_ctx_slot() const {
2383-
return params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
2384-
}
2385-
23862382
bool load_model(const common_params & params) {
23872383
SRV_INF("loading model '%s'\n", params.model.path.c_str());
23882384

@@ -2411,7 +2407,7 @@ struct server_context {
24112407

24122408
params_dft.devices = params_base.speculative.devices;
24132409
params_dft.model = params_base.speculative.model;
2414-
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? n_ctx_slot() : params_base.speculative.n_ctx;
2410+
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
24152411
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
24162412
params_dft.n_parallel = 1;
24172413
params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2506,7 +2502,7 @@ struct server_context {
25062502

25072503
slot.id = i;
25082504
slot.ctx = ctx;
2509-
slot.n_ctx = n_ctx_slot();
2505+
slot.n_ctx = llama_n_ctx_seq(ctx);
25102506
slot.mctx = mctx;
25112507
slot.prompt.tokens.has_mtmd = mctx != nullptr;
25122508

tools/server/tests/unit/test_chat_completion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
433433
@pytest.mark.parametrize(
434434
"n_batch,batch_count,reuse_cache",
435435
[
436-
(64, 15, False),
436+
(64, 3, False),
437437
(64, 1, True),
438438
]
439439
)
440-
def test_return_progresssss(n_batch, batch_count, reuse_cache):
440+
def test_return_progress(n_batch, batch_count, reuse_cache):
441441
global server
442442
server.n_batch = n_batch
443-
server.n_ctx = 2048
443+
server.n_ctx = 256
444444
server.n_slots = 1
445445
server.start()
446446
def make_cmpl_request():
447447
return server.make_stream_request("POST", "/chat/completions", data={
448448
"max_tokens": 10,
449449
"messages": [
450-
{"role": "user", "content": "This is a test" * 100},
450+
{"role": "user", "content": "This is a test" * 10},
451451
],
452452
"stream": True,
453453
"return_progress": True,

tools/server/tests/unit/test_infill.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
1818
"input_suffix": "}\n",
1919
})
2020
assert res.status_code == 200
21-
assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
21+
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
2222

2323

2424
def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
3434
"input_suffix": "}\n",
3535
})
3636
assert res.status_code == 200
37-
assert match_regex("(Dad|excited|park)+", res.body["content"])
37+
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
3838

3939

4040
@pytest.mark.parametrize("input_extra", [

0 commit comments

Comments
 (0)