diff --git a/common/arg.cpp b/common/arg.cpp
index 98baac4c14da2..d3868018ef203 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ctx_shift = false;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    add_opt(common_arg(
+        {"--context-shift"},
+        string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ctx_shift = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
diff --git a/common/common.h b/common/common.h
index 75596e6b32979..9376a3115c30e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -372,7 +372,7 @@ struct common_params {
     bool cont_batching = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn    = false; // flash attention
     bool no_perf       = false; // disable performance metrics
-    bool ctx_shift     = true;  // context shift on inifinite text generation
+    bool ctx_shift     = false; // context shift on infinite text generation
     bool swa_full      = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
 
     bool kv_unified    = false; // enable unified KV cache
diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py
index 1485de8ceb3fc..c7b3af0489164 100644
--- a/tools/server/tests/unit/test_basic.py
+++ b/tools/server/tests/unit/test_basic.py
@@ -5,7 +5,7 @@
 
 server = ServerPreset.tinyllama2()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index be3a0052c64fe..adb6f27864ef9 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -7,7 +7,7 @@
 
 server = ServerPreset.tinyllama2()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
         "temperature": 1.0,
         "cache_prompt": False,
     })
-    assert res.status_code == 200
+    assert res.status_code == 400
 
 
 def test_completion_with_tokens_input():
diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py
index 2431ac70882d7..8f51bc301a74c 100644
--- a/tools/server/tests/unit/test_ctx_shift.py
+++ b/tools/server/tests/unit/test_ctx_shift.py
@@ -11,7 +11,7 @@
 Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() @@ -25,6 +25,7 @@ def test_ctx_shift_enabled(): # the prompt is truncated to keep the last 109 tokens # 64 tokens are generated thanks to shifting the context when it gets full global server + server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 64, @@ -42,7 +43,6 @@ def test_ctx_shift_enabled(): ]) def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): global server - server.disable_ctx_shift = True server.n_predict = -1 server.start() res = server.make_request("POST", "/completion", data={ @@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr def test_ctx_shift_disabled_long_prompt(): global server - server.disable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 64, @@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt(): def test_ctx_shift_disabled_stream(): global server - server.disable_ctx_shift = True server.start() res = server.make_stream_request("POST", "/v1/completions", data={ "n_predict": 256, diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py index 0feb452ccfcd4..50601b839650f 100644 --- a/tools/server/tests/unit/test_embedding.py +++ b/tools/server/tests/unit/test_embedding.py @@ -8,7 +8,7 @@ EPSILON = 1e-3 -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.bert_bge_small() diff --git a/tools/server/tests/unit/test_infill.py b/tools/server/tests/unit/test_infill.py index 10554db0f623e..73dacdae812b8 100644 --- a/tools/server/tests/unit/test_infill.py +++ b/tools/server/tests/unit/test_infill.py @@ -3,7 +3,7 @@ server = ServerPreset.tinyllama_infill() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama_infill() diff --git a/tools/server/tests/unit/test_lora.py b/tools/server/tests/unit/test_lora.py index c1aa8be70e2f7..00b2f245f60fc 100644 --- a/tools/server/tests/unit/test_lora.py +++ b/tools/server/tests/unit/test_lora.py @@ -5,7 +5,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.stories15m_moe() diff --git a/tools/server/tests/unit/test_rerank.py b/tools/server/tests/unit/test_rerank.py index f4f570ad5ef78..0b63c7821eb98 100644 --- a/tools/server/tests/unit/test_rerank.py +++ b/tools/server/tests/unit/test_rerank.py @@ -4,7 +4,7 @@ server = ServerPreset.jina_reranker_tiny() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.jina_reranker_tiny() diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 620b25376bd81..0e11580553aa6 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -6,7 +6,7 @@ TEST_API_KEY = "sk-this-is-the-secret-key" -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git 
diff --git a/tools/server/tests/unit/test_slot_save.py b/tools/server/tests/unit/test_slot_save.py
index 38704f5ece35a..1b428cc2a840c 100644
--- a/tools/server/tests/unit/test_slot_save.py
+++ b/tools/server/tests/unit/test_slot_save.py
@@ -3,7 +3,7 @@
 
 server = ServerPreset.tinyllama2()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py
index 54db38cf3bd80..38ca4325ba675 100644
--- a/tools/server/tests/unit/test_speculative.py
+++ b/tools/server/tests/unit/test_speculative.py
@@ -16,7 +16,7 @@ def create_server():
     server.draft_max = 8
 
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def fixture_create_server():
     return create_server()
 
@@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded():
 def test_with_ctx_shift():
     global server
     server.n_ctx = 64
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "Hello " * 56,
diff --git a/tools/server/tests/unit/test_tokenize.py b/tools/server/tests/unit/test_tokenize.py
index 382457c9d602f..424cac5f3d394 100644
--- a/tools/server/tests/unit/test_tokenize.py
+++ b/tools/server/tests/unit/test_tokenize.py
@@ -4,7 +4,7 @@
 
 server = ServerPreset.tinyllama2()
 
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py
index 20f048c6f6aa5..a3c3ccdf586ab 100755
--- a/tools/server/tests/unit/test_tool_call.py
+++ b/tools/server/tests/unit/test_tool_call.py
@@ -22,6 +22,8 @@ def create_server():
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
     server.n_slots = 1
+    server.n_ctx = 8192
+    server.n_batch = 2048
 
 class CompletionMode(Enum):
     NORMAL = "normal"
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index bc547ca03bf1b..49277e6000236 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -79,7 +79,7 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     lora_files: List[str] | None = None
-    disable_ctx_shift: int | None = False
+    enable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
@@ -178,8 +178,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.disable_ctx_shift:
-            server_args.extend(["--no-context-shift"])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
         if self.draft_max:
diff --git a/tools/tts/tts.cpp b/tools/tts/tts.cpp
index a71e9bf5b589e..18f01a9946350 100644
--- a/tools/tts/tts.cpp
+++ b/tools/tts/tts.cpp
@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {
 
     params.model = params.vocoder.model;
     params.embedding = true;
-    params.ctx_shift = false; // silence warning
    params.n_ubatch = params.n_batch;
 
     common_init_result llama_init_cts = common_init_from_params(params);
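
With context shifting now opt-in, a server test that relies on it has to request it explicitly. The following is a minimal sketch of that opt-in path, not part of the patch: it assumes the `ServerPreset`/`ServerProcess` helpers from `tools/server/tests/utils.py` above, and the test name `test_ctx_shift_opt_in` is hypothetical.

```python
# Hypothetical example (not part of this patch): since ctx_shift is now off
# by default, a test sets enable_ctx_shift so that ServerProcess.start()
# appends --context-shift to the server's launch arguments.
import pytest
from utils import ServerPreset

server = ServerPreset.tinyllama2()


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()


def test_ctx_shift_opt_in():
    global server
    server.enable_ctx_shift = True  # maps to the new --context-shift flag
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": "Hello " * 100,
    })
    # with context shift enabled, a prompt longer than n_ctx still succeeds
    assert res.status_code == 200
```

On the command line the equivalent opt-in is `--context-shift`, or the `LLAMA_ARG_CONTEXT_SHIFT` environment variable, mirroring the pre-existing `--no-context-shift` / `LLAMA_ARG_NO_CONTEXT_SHIFT` pair.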