Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3727,7 +3727,7 @@ struct server_context {
}
} else {
if (slot.n_prompt_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
Expand Down Expand Up @@ -4955,9 +4955,17 @@ int main(int argc, char ** argv) {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}

const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking maybe this check can better be done inside launch_slot_with_task? There you will have access to slot.n_ctx

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately there is no way to return non-200 status code once you call res.set_chunked_content_provider(...). That's why I am doing the check before that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm ok then, I'll refactor this code in #16488 , for now this can be a temporary soltuion

tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto n_prompt_tokens = inputs[i].size();
if (n_prompt_tokens >= n_ctx_slot) {
json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
error_data["n_prompt_tokens"] = n_prompt_tokens;
error_data["n_ctx"] = n_ctx_slot;
res_error(res, error_data);
Comment on lines +4963 to +4966
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is handled inside launch_slot_with_task, you can call send_error(slot, ".....", ERROR_TYPE_EXCEED_CONTEXT_SIZE);, which should simplify things a bit

return;
}
server_task task = server_task(type);

task.id = ctx_server.queue_tasks.get_new_id();
Expand Down
22 changes: 22 additions & 0 deletions tools/server/tests/unit/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ def test_context_size_exceeded():
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


def test_context_size_exceeded_stream():
global server
server.start()
try:
for _ in server.make_stream_request("POST", "/chat/completions", data={
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
] * 100, # make the prompt too long
"stream": True}):
pass
assert False, "Should have failed"
except ServerError as e:
assert e.code == 400
assert "error" in e.body
assert e.body["error"]["type"] == "exceed_context_size_error"
assert e.body["error"]["n_prompt_tokens"] > 0
assert server.n_ctx is not None
assert server.n_slots is not None
assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
Expand Down
8 changes: 8 additions & 0 deletions tools/server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ class ServerResponse:
body: dict | Any


class ServerError(Exception):
def __init__(self, code, body):
self.code = code
self.body = body


class ServerProcess:
# default options
debug: bool = False
Expand Down Expand Up @@ -297,6 +303,8 @@ def make_stream_request(
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
if response.status_code != 200:
raise ServerError(response.status_code, response.json())
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line:
Expand Down
Loading