Skip to content

Commit aac559d

Browse files
committed
server : return HTTP 400 if prompt exceeds context length
In streaming mode, when the prompt exceeds the context length, the server returns an HTTP 200 status code with a JSON error in the body. This is very confusing and inconsistent with all other inference engines, which return an HTTP 4xx error in this case. This patch fixes the problem and makes the server return HTTP 400 in such cases.
1 parent 56b4795 commit aac559d

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

tools/server/server.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4642,9 +4642,17 @@ int main(int argc, char ** argv) {
46424642
// Everything else, including multimodal completions.
46434643
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
46444644
}
4645-
4645+
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
46464646
tasks.reserve(inputs.size());
46474647
for (size_t i = 0; i < inputs.size(); i++) {
4648+
auto n_prompt_tokens = inputs[i].size();
4649+
if (!ctx_server.params_base.ctx_shift && n_prompt_tokens >= n_ctx_slot) {
4650+
json error_data = format_error_response("the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
4651+
error_data["n_prompt_tokens"] = n_prompt_tokens;
4652+
error_data["n_ctx"] = n_ctx_slot;
4653+
res_error(res, error_data);
4654+
return;
4655+
}
46484656
server_task task = server_task(type);
46494657

46504658
task.id = ctx_server.queue_tasks.get_new_id();

tools/server/tests/unit/test_chat_completion.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,28 @@ def test_context_size_exceeded():
408408
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
409409

410410

411+
def test_context_size_exceeded_stream():
    """Check that an over-long prompt is rejected with HTTP 400 in streaming mode.

    Sends a streaming chat completion whose prompt is far larger than the
    per-slot context and expects ``make_stream_request`` to raise
    ``ServerError`` carrying the ``exceed_context_size_error`` payload.
    """
    global server
    server.start()
    # pytest.raises fails the test by itself when nothing is raised, and keeps
    # the payload assertions below out of the exception-handling context.
    with pytest.raises(ServerError) as exc_info:
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "user", "content": "this is a very long prompt " * 1000},
            ],
            "stream": True,
        }):
            pass
    err = exc_info.value
    assert err.code == 400
    assert "error" in err.body
    assert err.body["error"]["type"] == "exceed_context_size_error"
    assert err.body["error"]["n_prompt_tokens"] > 0
    assert server.n_ctx is not None
    assert server.n_slots is not None
    # The reported context size is compared against the per-slot share.
    assert err.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
431+
432+
411433
@pytest.mark.parametrize(
412434
"n_batch,batch_count,reuse_cache",
413435
[

tools/server/tests/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ class ServerResponse:
3535
body: dict | Any
3636

3737

38+
class ServerError(Exception):
    """Error carrying the HTTP status code and body of a failed server response.

    Attributes:
        code: HTTP status code of the response.
        body: Response body supplied by the caller.
    """

    def __init__(self, code: int, body: Any) -> None:
        # Forward both values so ``args`` stays ``(code, body)``, which keeps
        # the exception copyable and picklable.
        super().__init__(code, body)
        self.code = code
        self.body = body

    def __str__(self) -> str:
        """Return a readable one-line summary for tracebacks and test reports."""
        return f"HTTP {self.code}: {self.body}"
42+
43+
3844
class ServerProcess:
3945
# default options
4046
debug: bool = False
@@ -297,6 +303,8 @@ def make_stream_request(
297303
response = requests.post(url, headers=headers, json=data, stream=True)
298304
else:
299305
raise ValueError(f"Unimplemented method: {method}")
306+
if response.status_code != 200:
307+
raise ServerError(response.status_code, response.json())
300308
for line_bytes in response.iter_lines():
301309
line = line_bytes.decode("utf-8")
302310
if '[DONE]' in line:

0 commit comments

Comments
 (0)