
Commit 55bb9db

llama : update per-seq context computation
1 parent 49f7a3c commit 55bb9db

File tree

8 files changed: 40 additions, 36 deletions

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -461,6 +461,7 @@ extern "C" {
     LLAMA_API bool llama_supports_rpc (void);
 
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
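
A minimal usage sketch of the new accessor, assuming a llama_context created elsewhere (model loading and parameter setup omitted); it only calls getters declared in this header:

#include "llama.h"
#include <cstdint>
#include <cstdio>

// Print the total context, the per-sequence context, and the number of
// sequences the context was created with.
static void print_ctx_info(const struct llama_context * ctx) {
    const uint32_t n_ctx     = llama_n_ctx(ctx);      // total context across all sequences
    const uint32_t n_ctx_seq = llama_n_ctx_seq(ctx);  // context available to one sequence
    const uint32_t n_seq_max = llama_n_seq_max(ctx);  // maximum number of parallel sequences

    std::printf("n_ctx = %u, n_ctx_seq = %u, n_seq_max = %u\n", n_ctx, n_ctx_seq, n_seq_max);
}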

src/llama-context.cpp

Lines changed: 21 additions & 9 deletions
@@ -112,9 +112,17 @@ llama_context::llama_context(
         }
     }
 
+    cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
+
+        cparams.n_ctx_seq = hparams.n_ctx_train;
+    }
+
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq());
+    LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -123,14 +131,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq() < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq() > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -451,8 +459,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }
 
 uint32_t llama_context::n_batch() const {
@@ -2381,6 +2389,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
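
With this change the per-sequence context is computed once in the constructor, cached in cparams.n_ctx_seq, and n_ctx_seq() simply returns the cached value. A standalone sketch of the same arithmetic, including the cap at the training context; the function and values here are illustrative only, not library code:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the formula above: a unified KV cache shares the full context,
// otherwise the context is split evenly across sequences; the result is
// capped at the model's training context (the real code also logs a warning).
static uint32_t compute_n_ctx_seq(uint32_t n_ctx, uint32_t n_seq_max, bool kv_unified, uint32_t n_ctx_train) {
    const uint32_t n_ctx_seq = kv_unified ? n_ctx : n_ctx / n_seq_max;
    return std::min(n_ctx_seq, n_ctx_train);
}

int main() {
    std::printf("%u\n", compute_n_ctx_seq(16384, 4, false, 8192)); // 4096 (16384 split across 4 sequences)
    std::printf("%u\n", compute_n_ctx_seq(16384, 4, true,  8192)); // 8192 (unified cache, capped to n_ctx_train)
    return 0;
}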

src/llama-context.h

Lines changed: 5 additions & 5 deletions
@@ -43,11 +43,11 @@ struct llama_context {
 
     ggml_backend_sched_t get_sched() const;
 
-    uint32_t n_ctx()         const;
-    uint32_t n_ctx_per_seq() const;
-    uint32_t n_batch()       const;
-    uint32_t n_ubatch()      const;
-    uint32_t n_seq_max()     const;
+    uint32_t n_ctx()     const;
+    uint32_t n_ctx_seq() const;
+    uint32_t n_batch()   const;
+    uint32_t n_ubatch()  const;
+    uint32_t n_seq_max() const;
 
     uint32_t n_threads() const;
     uint32_t n_threads_batch() const;

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
+    uint32_t n_ctx_seq; // context for a single sequence
    uint32_t n_batch;
    uint32_t n_ubatch;
    uint32_t n_seq_max;

src/llama-model.cpp

Lines changed: 4 additions & 10 deletions
@@ -6581,14 +6581,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_seq = cparams.n_ctx_seq;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+    if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
         return layers[il].rope_long;
     }
 
@@ -19710,12 +19710,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             /* filter_attn */ std::move(filter_attn),
             /* filter_recr */ std::move(filter_recr));
     } else {
-        uint32_t n_ctx_per_stream = cparams.n_ctx;
-
-        if (!cparams.kv_unified) {
-            n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-        }
-
         llama_memory_i::layer_reuse_cb reuse = nullptr;
 
         if (arch == LLM_ARCH_GEMMA3N) {
@@ -19739,7 +19733,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 cparams.offload_kqv,
                 params.swa_full,
                 cparams.kv_unified,
-                n_ctx_per_stream,
+                cparams.n_ctx_seq,
                 cparams.n_seq_max,
                 cparams.n_ubatch,
                 1,
@@ -19755,7 +19749,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 !cparams.flash_attn,
                 cparams.offload_kqv,
                 cparams.kv_unified,
-                n_ctx_per_stream,
+                cparams.n_ctx_seq,
                 cparams.n_seq_max,
                 1,
                 hparams.n_swa,
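
The removed n_ctx_per_stream in create_memory rounded up with (n_ctx + n_seq_max - 1) / n_seq_max, while cparams.n_ctx_seq (set in llama-context.cpp above) uses plain integer division. A short standalone illustration of the two formulas from the diff, with example values; the variable names are local to the sketch:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_ctx     = 8192;
    const uint32_t n_seq_max = 3;

    const uint32_t per_stream = (n_ctx + n_seq_max - 1) / n_seq_max; // removed formula, rounds up:    2731
    const uint32_t per_seq    = n_ctx / n_seq_max;                   // cached n_ctx_seq, rounds down: 2730

    std::printf("per_stream = %u, per_seq = %u\n", per_stream, per_seq);
    return 0;
}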

tools/server/server.cpp

Lines changed: 2 additions & 6 deletions
@@ -2385,10 +2385,6 @@ struct server_context {
         llama_batch_free(batch);
     }
 
-    int32_t n_ctx_slot() const {
-        return params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
-    }
-
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
@@ -2417,7 +2413,7 @@
 
         params_dft.devices = params_base.speculative.devices;
         params_dft.model = params_base.speculative.model;
-        params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? n_ctx_slot() : params_base.speculative.n_ctx;
+        params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
         params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
         params_dft.n_parallel = 1;
         params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2512,7 +2508,7 @@
 
            slot.id = i;
            slot.ctx = ctx;
-           slot.n_ctx = n_ctx_slot();
+           slot.n_ctx = llama_n_ctx_seq(ctx);
            slot.mctx = mctx;
            slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
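
Since slot.n_ctx and the default draft context now come from llama_n_ctx_seq(ctx), an embedding application can bound each request by the same value. A hedged sketch using a hypothetical helper (prompt_fits is not part of server.cpp):

#include "llama.h"
#include <cstdint>

// Hypothetical check: does a tokenized prompt plus the requested number of
// generated tokens fit within the per-sequence context reported by the library?
static bool prompt_fits(const struct llama_context * ctx, uint32_t n_prompt_tokens, uint32_t n_predict) {
    const uint32_t n_ctx_seq = llama_n_ctx_seq(ctx); // same value server.cpp assigns to slot.n_ctx
    return n_prompt_tokens + n_predict <= n_ctx_seq;
}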

tools/server/tests/unit/test_chat_completion.py

Lines changed: 4 additions & 4 deletions
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 15, False),
+        (64, 3, False),
         (64, 1, True),
     ]
 )
-def test_return_progresssss(n_batch, batch_count, reuse_cache):
+def test_return_progress(n_batch, batch_count, reuse_cache):
     global server
     server.n_batch = n_batch
-    server.n_ctx = 2048
+    server.n_ctx = 256
     server.n_slots = 1
     server.start()
     def make_cmpl_request():
         return server.make_stream_request("POST", "/chat/completions", data={
             "max_tokens": 10,
             "messages": [
-                {"role": "user", "content": "This is a test" * 100},
+                {"role": "user", "content": "This is a test" * 10},
             ],
             "stream": True,
             "return_progress": True,

tools/server/tests/unit/test_infill.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Dad|excited|park)+", res.body["content"])
+    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [

0 commit comments
