Merged
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ctx_shift = false;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    add_opt(common_arg(
+        {"--context-shift"},
+        string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ctx_shift = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
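
Note on behavior: with this change, context shift becomes opt-in. The existing --no-context-shift option (context above) still forces it off, while the new --context-shift flag or the LLAMA_ARG_CONTEXT_SHIFT environment variable turns it on. A minimal launch sketch, assuming a local llama-server build and model path (both illustrative, not from this PR):

    import subprocess

    # Illustrative binary/model paths; adjust to your build.
    proc = subprocess.Popen([
        "./llama-server",
        "-m", "model.gguf",
        "--ctx-size", "256",
        "--context-shift",  # opt in: oldest tokens are shifted out when the window fills
    ])
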
2 changes: 1 addition & 1 deletion common/common.h
@@ -372,7 +372,7 @@ struct common_params {
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
-    bool ctx_shift = true; // context shift on infinite text generation
+    bool ctx_shift = false; // context shift on infinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified = false; // enable unified KV cache

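
This default flip is the heart of the PR: code that previously relied on implicit shifting must now opt in. Since arg.cpp registers LLAMA_ARG_CONTEXT_SHIFT for the new flag, the opt-in can also happen through the environment; a sketch (treating "1" as truthy is an assumption about the env parsing):

    import os
    import subprocess

    # Same effect as passing --context-shift on the command line.
    env = dict(os.environ, LLAMA_ARG_CONTEXT_SHIFT="1")
    subprocess.run(["./llama-server", "-m", "model.gguf"], env=env)
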
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_basic.py
@@ -5,7 +5,7 @@
 server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
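
Dropping scope="module" here (and in the sibling test files below) changes the fixture from once-per-file to pytest's default once-per-test, so every test starts from a freshly constructed ServerPreset and per-test settings such as enable_ctx_shift cannot leak between tests. A self-contained illustration of the scoping difference (fixture names are hypothetical):

    import pytest

    calls = {"module": 0, "function": 0}

    @pytest.fixture(scope="module", autouse=True)
    def per_module():
        calls["module"] += 1  # set up once for the whole file

    @pytest.fixture(autouse=True)  # default scope is "function"
    def per_function():
        calls["function"] += 1  # set up again before every test

    def test_a():
        pass

    def test_b():
        assert calls == {"module": 1, "function": 2}
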
4 changes: 2 additions & 2 deletions tools/server/tests/unit/test_completion.py
@@ -7,7 +7,7 @@
 server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -229,7 +229,7 @@ def test_nocache_long_input_prompt():
         "temperature": 1.0,
         "cache_prompt": False,
     })
-    assert res.status_code == 200
+    assert res.status_code == 400


 def test_completion_with_tokens_input():
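
The flipped assertion follows from the new default: without context shift, a prompt longer than the context window with cache_prompt disabled cannot be processed, so the server now rejects it instead of answering. A hedged sketch of what a client should expect (the /completion endpoint shape follows the tests above; host and port are illustrative):

    import requests

    res = requests.post("http://localhost:8080/completion", json={
        "prompt": "lorem ipsum " * 10_000,  # far longer than the context window
        "cache_prompt": False,
        "temperature": 1.0,
    })
    if res.status_code == 400:
        # with context shift off (the new default), oversized prompts are refused
        print("rejected:", res.text)
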
6 changes: 2 additions & 4 deletions tools/server/tests/unit/test_ctx_shift.py
@@ -11,7 +11,7 @@
 Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
 """.strip()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
@@ -25,6 +25,7 @@ def test_ctx_shift_enabled():
     # the prompt is truncated to keep the last 109 tokens
     # 64 tokens are generated thanks to shifting the context when it gets full
     global server
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -42,7 +43,6 @@
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
-    server.disable_ctx_shift = True
     server.n_predict = -1
     server.start()
     res = server.make_request("POST", "/completion", data={
@@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):

 def test_ctx_shift_disabled_long_prompt():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
@@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt():

 def test_ctx_shift_disabled_stream():
     global server
-    server.disable_ctx_shift = True
     server.start()
     res = server.make_stream_request("POST", "/v1/completions", data={
         "n_predict": 256,
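
The net pattern in this file: the enabled case opts in explicitly, and the disabled cases need no setup because off is now the default. Schematically, with the ServerPreset helpers from utils.py:

    def test_ctx_shift_enabled():
        global server
        server.enable_ctx_shift = True  # replaces the removed disable_ctx_shift field
        server.start()

    def test_ctx_shift_disabled_long_prompt():
        global server
        server.start()  # nothing to set: disabled is the default
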
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_embedding.py
@@ -8,7 +8,7 @@

 EPSILON = 1e-3

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.bert_bge_small()
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_infill.py
@@ -3,7 +3,7 @@

 server = ServerPreset.tinyllama_infill()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama_infill()
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_lora.py
@@ -5,7 +5,7 @@

 LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_rerank.py
@@ -4,7 +4,7 @@
 server = ServerPreset.jina_reranker_tiny()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.jina_reranker_tiny()
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_security.py
@@ -6,7 +6,7 @@

 TEST_API_KEY = "sk-this-is-the-secret-key"

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_slot_save.py
@@ -3,7 +3,7 @@

 server = ServerPreset.tinyllama2()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
3 changes: 2 additions & 1 deletion tools/server/tests/unit/test_speculative.py
@@ -16,7 +16,7 @@ def create_server():
     server.draft_max = 8


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def fixture_create_server():
     return create_server()

@@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded():
 def test_with_ctx_shift():
     global server
     server.n_ctx = 64
+    server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "Hello " * 56,
2 changes: 1 addition & 1 deletion tools/server/tests/unit/test_tokenize.py
@@ -4,7 +4,7 @@
 server = ServerPreset.tinyllama2()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
2 changes: 2 additions & 0 deletions tools/server/tests/unit/test_tool_call.py
@@ -22,6 +22,8 @@ def create_server():
     server.model_alias = "tinyllama-2-tool-call"
     server.server_port = 8081
     server.n_slots = 1
+    server.n_ctx = 8192
+    server.n_batch = 2048

 class CompletionMode(Enum):
     NORMAL = "normal"
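
The larger n_ctx and n_batch compensate for shifting being off by default: the long tool-call prompts must now fit in the window outright. The constraint being relied on, as a sketch (the helper is hypothetical, not part of the test suite):

    def fits_without_shift(n_prompt_tokens: int, n_predict: int, n_ctx: int = 8192) -> bool:
        # with context shift disabled, prompt plus generated tokens must fit in n_ctx
        return n_prompt_tokens + n_predict <= n_ctx
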
6 changes: 3 additions & 3 deletions tools/server/tests/utils.py
@@ -79,7 +79,7 @@ class ServerProcess:
     draft: int | None = None
     api_key: str | None = None
     lora_files: List[str] | None = None
-    disable_ctx_shift: int | None = False
+    enable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
@@ -178,8 +178,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
         if self.lora_files:
             for lora_file in self.lora_files:
                 server_args.extend(["--lora", lora_file])
-        if self.disable_ctx_shift:
-            server_args.extend(["--no-context-shift"])
+        if self.enable_ctx_shift:
+            server_args.append("--context-shift")
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
         if self.draft_max:
1 change: 0 additions & 1 deletion tools/tts/tts.cpp
@@ -581,7 +581,6 @@ int main(int argc, char ** argv) {

     params.model = params.vocoder.model;
     params.embedding = true;
-    params.ctx_shift = false; // silence warning
     params.n_ubatch = params.n_batch;

     common_init_result llama_init_cts = common_init_from_params(params);