
Commit 0876d42

server : disable context shift by default
ggml-ci
1 parent 6d7f111 commit 0876d42

File tree

8 files changed, +18 -9 lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ctx_shift = false;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    add_opt(common_arg(
+        {"--context-shift"},
+        string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ctx_shift = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ struct common_params {
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
-    bool ctx_shift = true; // context shift on infinite text generation
+    bool ctx_shift = false; // context shift on infinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified = false; // enable unified KV cache

tools/server/tests/unit/test_completion.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
         "temperature": 1.0,
         "cache_prompt": False,
     })
-    assert res.status_code == 200
+    assert res.status_code == 400
 
 
 def test_completion_with_tokens_input():
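
The flipped assertion captures the user-visible effect of the new default: with context shift off, a prompt longer than the context window is rejected with HTTP 400 instead of being shifted to fit. A sketch of the same check outside the test harness, assuming a server on localhost:8080 started without --context-shift and the third-party requests package:

import requests

# A prompt deliberately longer than the server's context window.
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "lorem ipsum " * 4096,
    "n_predict": 16,
    "cache_prompt": False,
})
assert res.status_code == 400  # was 200 (with truncation) before this commit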

tools/server/tests/unit/test_ctx_shift.py

Lines changed: 2 additions & 3 deletions
@@ -25,7 +25,9 @@ def test_ctx_shift_enabled():
     # the prompt is truncated to keep the last 109 tokens
     # 64 tokens are generated thanks to shifting the context when it gets full
     global server
+    server.enable_ctx_shift = True
     server.start()
+    server.enable_ctx_shift = False
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
         "prompt": LONG_TEXT,
@@ -42,7 +44,6 @@ def test_ctx_shift_enabled():
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
-    server.disable_ctx_shift = True
     server.n_predict = -1
     server.start()
     res = server.make_request("POST", "/completion", data={
@@ -56,7 +57,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
 
 def test_ctx_shift_disabled_long_prompt():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -68,7 +68,6 @@ def test_ctx_shift_disabled_long_prompt():
 
 def test_ctx_shift_disabled_stream():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_stream_request("POST", "/v1/completions", data={
         "n_predict": 256,

tools/server/tests/unit/test_speculative.py

Lines changed: 2 additions & 0 deletions
@@ -91,7 +91,9 @@ def test_slot_ctx_not_exceeded():
 def test_with_ctx_shift():
     global server
     server.n_ctx = 64
+    server.enable_ctx_shift = True
     server.start()
+    server.enable_ctx_shift = False
     res = server.make_request("POST", "/completion", data={
         "prompt": "Hello " * 56,
         "temperature": 0.0,

tools/server/tests/unit/test_tool_call.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@ def create_server():
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
     server.n_slots = 1
+    server.n_ctx = 8192
+    server.n_batch = 2048
 
 class CompletionMode(Enum):
     NORMAL = "normal"

tools/server/tests/utils.py

Lines changed: 3 additions & 3 deletions
@@ -79,7 +79,7 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     lora_files: List[str] | None = None
-    disable_ctx_shift: int | None = False
+    enable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
@@ -178,8 +178,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.disable_ctx_shift:
-            server_args.extend(["--no-context-shift"])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
         if self.draft_max:
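
The harness field mirrors the CLI change: disable_ctx_shift used to append --no-context-shift, whereas enable_ctx_shift now appends --context-shift, so tests opt in explicitly. A usage sketch, assuming the remaining server configuration (model, port) is filled in elsewhere as the existing tests do:

from utils import ServerProcess  # tools/server/tests/utils.py

server = ServerProcess()
server.enable_ctx_shift = True  # start() will append --context-shift
server.start()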

tools/tts/tts.cpp

Lines changed: 0 additions & 1 deletion
@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {
 
     params.model = params.vocoder.model;
     params.embedding = true;
-    params.ctx_shift = false; // silence warning
     params.n_ubatch = params.n_batch;
 
     common_init_result llama_init_cts = common_init_from_params(params);
