Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4642,9 +4642,17 @@ int main(int argc, char ** argv) {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}

const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto n_prompt_tokens = inputs[i].size();
if (!ctx_server.params_base.ctx_shift && n_prompt_tokens >= n_ctx_slot) {
json error_data = format_error_response("the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
Comment on lines +4649 to +4650
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prompt truncation functionality is being removed in #16391:

image

So there is no longer a need to check ctx_shift here and, accordingly, no need to suggest enabling it in the error message.

error_data["n_prompt_tokens"] = n_prompt_tokens;
error_data["n_ctx"] = n_ctx_slot;
res_error(res, error_data);
return;
}
server_task task = server_task(type);

task.id = ctx_server.queue_tasks.get_new_id();
Expand Down
22 changes: 22 additions & 0 deletions tools/server/tests/unit/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ def test_context_size_exceeded():
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


def test_context_size_exceeded_stream():
    """Streaming request whose prompt exceeds the slot context must fail
    with a 400 `exceed_context_size_error` before any tokens are produced."""
    global server
    server.start()
    try:
        # The generator performs the HTTP request lazily; draining it is what
        # triggers the server-side error.
        stream = server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "user", "content": "this is a very long prompt " * 1000},
            ],
            "stream": True,
        })
        for _chunk in stream:
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        err = e.body["error"]
        assert err["type"] == "exceed_context_size_error"
        assert err["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        assert err["n_ctx"] == server.n_ctx // server.n_slots


@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
Expand Down
8 changes: 8 additions & 0 deletions tools/server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ class ServerResponse:
body: dict | Any


class ServerError(Exception):
    """Raised when the server answers a streaming request with a non-200 status.

    Attributes:
        code: HTTP status code returned by the server.
        body: Parsed JSON response body (typically contains an "error" object).
    """

    def __init__(self, code, body):
        # Forward a readable message to Exception so str(e) and pytest
        # tracebacks are informative instead of empty.
        super().__init__(f"server returned error code {code}")
        self.code = code
        self.body = body


class ServerProcess:
# default options
debug: bool = False
Expand Down Expand Up @@ -297,6 +303,8 @@ def make_stream_request(
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
if response.status_code != 200:
raise ServerError(response.status_code, response.json())
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line:
Expand Down
Loading