Merged
Changes from 1 commit
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        params.ctx_shift = false;
    }
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+ add_opt(common_arg(
+     {"--context-shift"},
+     string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
+     [](common_params & params) {
+         params.ctx_shift = true;
+     }
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
add_opt(common_arg(
    {"--chunks"}, "N",
    string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
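As a usage note (not part of the diff), the new opt-in can be exercised either with the `--context-shift` flag or the `LLAMA_ARG_CONTEXT_SHIFT` environment variable registered above. A rough Python sketch, assuming the usual llama.cpp convention that boolean env arguments take "1" or "true", and with placeholder binary/model paths:

```python
# Hypothetical launcher sketch -- binary and model paths are placeholders, not from this PR.
import os
import subprocess

# Option 1: enable context shift via the new command-line flag.
subprocess.run(["./llama-server", "-m", "model.gguf", "--context-shift"])

# Option 2: enable it via the environment variable (value assumed to be "1" or "true").
env = dict(os.environ, LLAMA_ARG_CONTEXT_SHIFT="1")
subprocess.run(["./llama-server", "-m", "model.gguf"], env=env)
```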
2 changes: 1 addition & 1 deletion common/common.h
@@ -372,7 +372,7 @@ struct common_params {
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
- bool ctx_shift = true; // context shift on infinite text generation
+ bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache

2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_completion.py
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
"temperature": 1.0,
"cache_prompt": False,
})
- assert res.status_code == 200
+ assert res.status_code == 400


def test_completion_with_tokens_input():
5 changes: 2 additions & 3 deletions tools/server/tests/unit/test_ctx_shift.py
@@ -25,7 +25,9 @@ def test_ctx_shift_enabled():
# the prompt is truncated to keep the last 109 tokens
# 64 tokens are generated thanks to shifting the context when it gets full
global server
+ server.enable_ctx_shift = True
server.start()
+ server.enable_ctx_shift = False
Member Author:

@ngxson I noticed that the server parameters are stateful - i.e. if we change a parameter in one test, it will remain changed for the rest of the tests. This is the reason I do it like this here.

Is there a better way to set the parameter just for the scope of the current test?

Collaborator:

It could be possible that the scope="module" is the problem. Could you try removing it? (While keeping autouse.)

I was a bit confused about the notion of scope in pytest.

Member Author:

Thanks - this seems to work.

res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
@@ -42,7 +44,6 @@ def test_ctx_shift_enabled()
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server
- server.disable_ctx_shift = True
server.n_predict = -1
server.start()
res = server.make_request("POST", "/completion", data={
@@ -56,7 +57,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr

def test_ctx_shift_disabled_long_prompt():
global server
- server.disable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
@@ -68,7 +68,6 @@ def test_ctx_shift_disabled_long_prompt():

def test_ctx_shift_disabled_stream():
global server
- server.disable_ctx_shift = True
server.start()
res = server.make_stream_request("POST", "/v1/completions", data={
"n_predict": 256,
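For reference, a minimal sketch of the per-test fixture shape suggested in the review thread above (the ServerPreset helper and preset name are assumptions based on the existing test utilities, not code from this PR):

```python
# Illustrative sketch, not the exact code from this PR: a function-scoped
# autouse fixture (i.e. no scope="module") rebuilds the server object for
# every test, so a parameter such as enable_ctx_shift changed by one test
# cannot leak into the next one.
import pytest

from utils import ServerPreset, ServerProcess  # helpers in tools/server/tests

server: ServerProcess

@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()  # assumed preset from the existing suite
```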
2 changes: 2 additions & 0 deletions tools/server/tests/unit/test_speculative.py
@@ -91,7 +91,9 @@ def test_slot_ctx_not_exceeded():
def test_with_ctx_shift():
global server
server.n_ctx = 64
+ server.enable_ctx_shift = True
server.start()
+ server.enable_ctx_shift = False
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
2 changes: 2 additions & 0 deletions tools/server/tests/unit/test_tool_call.py
@@ -22,6 +22,8 @@ def create_server():
server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081
server.n_slots = 1
+ server.n_ctx = 8192
+ server.n_batch = 2048

class CompletionMode(Enum):
NORMAL = "normal"
6 changes: 3 additions & 3 deletions tools/server/tests/utils.py
@@ -79,7 +79,7 @@ class ServerProcess:
draft: int | None = None
api_key: str | None = None
lora_files: List[str] | None = None
- disable_ctx_shift: int | None = False
+ enable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
no_webui: bool | None = None
@@ -178,8 +178,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
- if self.disable_ctx_shift:
-     server_args.extend(["--no-context-shift"])
+ if self.enable_ctx_shift:
+     server_args.append("--context-shift")
if self.api_key:
server_args.extend(["--api-key", self.api_key])
if self.draft_max:
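Because ServerProcess now defaults to enable_ctx_shift = False, tests that depend on context shifting have to opt in before start(); a rough sketch (the ServerPreset helper is an assumption based on the existing suite):

```python
# Illustrative only -- not code from this PR.
from utils import ServerPreset

server = ServerPreset.tinyllama2()  # assumed preset helper
server.enable_ctx_shift = True      # start() appends --context-shift only when this is set
server.start()

# Without the opt-in, a prompt that overflows the context window is now rejected
# (see the 200 -> 400 assertion change in test_completion.py above) instead of
# being shifted.
```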
1 change: 0 additions & 1 deletion tools/tts/tts.cpp
@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {

params.model = params.vocoder.model;
params.embedding = true;
- params.ctx_shift = false; // silence warning
params.n_ubatch = params.n_batch;

common_init_result llama_init_cts = common_init_from_params(params);