Skip to content

Commit 6777332

Browse files
committed
Respect n_predict=-2 in server
1 parent 2c9f833 commit 6777332

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

examples/server/server.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,17 +1321,24 @@ struct server_slot {
13211321
&& are_lora_equal(lora, other_slot.lora);
13221322
}
13231323

1324+
// There are two caps on the budge of a single request:
1325+
// * [params.n_predict]
1326+
// * [global_params.n_predict]
1327+
// This function returns true if the request is not limited by either of them.
13241328
bool has_budget(const common_params & global_params) {
13251329
if (params.n_predict == -1 && global_params.n_predict == -1) {
13261330
return true; // limitless
13271331
}
1332+
n_remaining = INT32_MAX;
13281333

1329-
n_remaining = -1;
1334+
// The request or server have finite limits on the number of tokens to generate.
1335+
if ((params.n_predict != -1 && params.n_predict != -2) || (global_params.n_predict != -1 && global_params.n_predict != -2)) {
1336+
n_remaining = std::min(n_remaining, params.n_predict - n_decoded);
1337+
}
13301338

1331-
if (params.n_predict != -1) {
1332-
n_remaining = params.n_predict - n_decoded;
1333-
} else if (global_params.n_predict != -1) {
1334-
n_remaining = global_params.n_predict - n_decoded;
1339+
// The request or server have limits based on the context window.
1340+
if (params.n_predict == -2 || global_params.n_predict == -2) {
1341+
n_remaining = std::min(n_remaining, n_ctx - n_decoded);
13351342
}
13361343

13371344
return n_remaining > 0; // no budget

test.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
curl --location 'http://localhost:8080/v1/chat/completions' \
2+
--header 'Content-Type: application/json' \
3+
--header 'Authorization: Bearer no-key' \
4+
--data '{
5+
"messages": [
6+
{
7+
"role": "user",
8+
"content": "Count from 1 to 4097 one at a time, separating each number with a newline. You should not abbreviate the numbers, but list out every single one."
9+
}
10+
],
11+
"n_predict": -2
12+
}'

0 commit comments

Comments
 (0)