Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.ctx_shift = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"--chunks"}, "N",
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
Expand Down
27 changes: 26 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,15 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}

// if context shift is disabled, we stop when it reaches the context limit
if (slot.n_decoded >= slot.n_ctx) {
slot.truncated = true;
slot.stopped_limit = true;
slot.has_next_token = false;

SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
}

if (llama_token_is_eog(model, result.tok)) {
slot.stopped_eos = true;
slot.has_next_token = false;
Expand Down Expand Up @@ -1480,7 +1489,7 @@ struct server_context {
if (result.error) {
error_handler(result.data);
cancel_tasks(id_tasks);
break;
return;
}

size_t idx = result.data["index"];
Expand Down Expand Up @@ -1827,6 +1836,14 @@ struct server_context {
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
if (!params.ctx_shift) {
// this check is redundant (for good)
// we should never get here, since n_predict is already limited
slot.release();
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
continue;
}

// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
Expand Down Expand Up @@ -1961,6 +1978,14 @@ struct server_context {
continue;
}
} else {
if (!params.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
}
}
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
Expand Down
62 changes: 62 additions & 0 deletions examples/server/tests/features/ctx_shift.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
@llama.cpp
@ctx_shift
Feature: llama.cpp server

Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
And BOS token is 1
And 42 as server seed
And 256 KV cache size
And 32 as batch size
And 2 slots

Scenario: Inference with context shift
And 64 server max tokens to predict
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
And the completion is truncated
And 109 prompt tokens are processed

Scenario Outline: Inference without context shift
And <n_predict> server max tokens to predict
And disable context shifting
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Hi how are you
"""
And a completion request with no api error
Then <n_token_output> tokens are predicted matching twind|Anna
And the completion is <truncated> truncated
And 8 prompt tokens are processed
Examples:
| n_predict | n_token_output | truncated |
| 64 | 64 | not |
| -1 | 120 | |

Scenario: Inference without context shift (expected error: prompt too long)
And disable context shifting
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with 400 api error

20 changes: 17 additions & 3 deletions examples/server/tests/features/embeddings.feature
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ Feature: llama.cpp server
# the bert-bge-small model has context size of 512
# since the generated prompts are as big as the batch size, we need to set the batch size to 512
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
And 512 as batch size
And 512 as ubatch size
And 2048 KV cache size
And 128 as batch size
And 128 as ubatch size
And 512 KV cache size
And embeddings extraction
Then the server is starting
Then the server is healthy
Expand All @@ -26,6 +26,20 @@ Feature: llama.cpp server
"""
Then embeddings are generated

Scenario: Embedding (error: prompt too long)
When embeddings are computed for:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And embeddings request with 500 api error

Scenario: OAI Embeddings compatibility
Given a model bert-bge-small
When an OAI compatible embeddings computation request for:
Expand Down
28 changes: 21 additions & 7 deletions examples/server/tests/features/steps/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None
context.lora_file = None
context.disable_ctx_shift = False

context.tasks_result = []
context.concurrent_tasks = []
Expand Down Expand Up @@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):

@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict: int):
context.n_server_predict = n_predict
context.n_server_predict = n_predict if n_predict > 0 else None


@step('{slot_save_path} as slot save path')
Expand Down Expand Up @@ -180,6 +181,9 @@ def step_server_embeddings(context):
def step_server_metrics(context):
context.server_metrics = True

@step('disable context shifting')
def step_server_metrics(context):
context.disable_ctx_shift = True

@step("the server is starting")
def step_start_server(context):
Expand Down Expand Up @@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
expect_api_error = api_error == 'raised'
expect_api_error = api_error == 'raised' or api_error != 'no'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
Expand All @@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}")
if expect_api_error:
if api_error == 'raised':
assert completion == 401, f"completion must be an 401 status code: {completion}"
elif api_error.isdigit():
api_error_code = int(api_error)
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"


@step('{predicted_n:d} tokens are predicted matching {re_content}')
Expand Down Expand Up @@ -645,6 +652,9 @@ def step_assert_embeddings(context):
for embedding in context.embeddings:
assert_embeddings(embedding)

@step('embeddings request with {api_error_code:d} api error')
def step_assert_embeddings(context, api_error_code: int):
assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"

@step('an OAI compatible embeddings computation request for')
@async_run_until_complete
Expand Down Expand Up @@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
return completion_response


async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
async with session.post(f'{base_url}/embedding',
json={
"content": content,
}) as response:
assert response.status == 200
response_json = await response.json()
return [response_json['embedding']]
if response.status == 200:
response_json = await response.json()
return [response_json['embedding']]
else:
return response.status


async def request_oai_embeddings(input, seed,
Expand Down Expand Up @@ -1372,6 +1384,8 @@ def start_server_background(context):
server_args.append('--verbose')
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
if context.disable_ctx_shift:
server_args.extend(['--no-context-shift'])

args = [str(arg) for arg in [context.server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
Expand Down
Loading