Skip to content

Commit d2419b3

Browse files
committed
many fixes
1 parent 0d6485f commit d2419b3

File tree

4 files changed

+45
-24
lines changed

4 files changed

+45
-24
lines changed

examples/server/server.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,8 @@ struct server_context {
11721172
res.n_decoded = slot.n_decoded;
11731173
res.n_prompt_tokens = slot.n_prompt_tokens;
11741174
res.content = tkn.text_to_send;
1175+
res.stop = slot.stop;
1176+
res.truncated = slot.truncated;
11751177

11761178
if (slot.params.sampling.n_probs > 0) {
11771179
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1186,7 +1188,8 @@ struct server_context {
11861188
}
11871189
}
11881190

1189-
if (slot.params.timings_per_token) {
1191+
// populate timings if this is the final response or timings_per_token is enabled
1192+
if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
11901193
res.timings = slot.get_timings();
11911194
}
11921195

@@ -1195,6 +1198,7 @@ struct server_context {
11951198

11961199
void send_final_response(server_slot & slot) {
11971200
if (slot.params.stream) {
1201+
// if in stream mode, send the last partial response
11981202
return send_partial_response(slot, {0, "", {}});
11991203
}
12001204

@@ -1209,6 +1213,8 @@ struct server_context {
12091213
res.n_tokens_cached = slot.n_past;
12101214
res.content = slot.generated_text;
12111215
res.stop = slot.stop;
1216+
res.truncated = slot.truncated;
1217+
res.timings = slot.get_timings();
12121218

12131219
res.generation_params = slot.params; // copy the parameters
12141220

@@ -1439,6 +1445,8 @@ struct server_context {
14391445
break;
14401446
}
14411447

1448+
SRV_ERR("received partial result, %s\n", result.to_json().dump().c_str());
1449+
14421450
if (result.stop != STOP_TYPE_NONE) {
14431451
if (++n_finished == id_tasks.size()) {
14441452
break;
@@ -1533,7 +1541,7 @@ struct server_context {
15331541
res.id = task.id;
15341542
res.n_idle_slots = n_idle_slots;
15351543
res.n_processing_slots = n_processing_slots;
1536-
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
1544+
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
15371545
res.t_start = metrics.t_start;
15381546

15391547
res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
@@ -1627,13 +1635,13 @@ struct server_context {
16271635
const double t_restore_ms = (t_end - t_start) / 1000.0;
16281636

16291637
server_task_result_slot_save_load result;
1630-
result.id = task.id;
1631-
result.id_slot = id_slot;
1632-
result.filename = filename;
1633-
result.is_save = false;
1634-
result.n_saved = token_count;
1635-
result.n_read = nread;
1636-
result.t_ms = t_restore_ms;
1638+
result.id = task.id;
1639+
result.id_slot = id_slot;
1640+
result.filename = filename;
1641+
result.is_save = false;
1642+
result.n_restored = token_count;
1643+
result.n_read = nread;
1644+
result.t_ms = t_restore_ms;
16371645
queue_results.send(result);
16381646
} break;
16391647
case SERVER_TASK_TYPE_SLOT_ERASE:

examples/server/server.hpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
using json = nlohmann::ordered_json;
1717

18+
// cast a smart pointer (e.g. unique_ptr or shared_ptr) to a specific type using the copy constructor
1819
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))
1920

2021
enum stop_type {
@@ -281,23 +282,34 @@ struct server_task_result_cmpl_partial : server_task_result {
281282
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
282283
int index = 0;
283284
std::string content;
285+
286+
bool truncated;
284287
int32_t n_decoded;
285288
int32_t n_prompt_tokens;
289+
286290
stop_type stop = STOP_TYPE_NONE;
287291
std::vector<completion_token_output> probs_output;
288292
result_timings timings;
289293

290294
json to_json() {
295+
bool is_stop = stop != STOP_TYPE_NONE;
296+
// non-OAI-compat JSON
291297
json res = json {
292-
{"index", index},
293-
{"content", content},
294-
{"stop", stop != STOP_TYPE_NONE},
295-
{"id_slot", id_slot},
298+
{"index", index},
299+
{"content", content},
300+
{"stop_type", stop_type_to_str(stop)},
301+
{"stop", is_stop},
302+
{"id_slot", id_slot},
303+
{"tokens_predicted", n_decoded},
304+
{"tokens_evaluated", n_prompt_tokens},
296305
};
297-
// populate the timings object when timings_per_token is set
306+
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
298307
if (timings.prompt_n > 0) {
299308
res.push_back({"timings", timings.to_json()});
300309
}
310+
if (is_stop) {
311+
res.push_back({"truncated", truncated});
312+
}
301313
return res;
302314
}
303315

@@ -464,7 +476,7 @@ struct server_task_result_slot_erase : server_task_result {
464476
{ "n_erased", n_erased },
465477
};
466478
}
467-
479+
468480
static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
469481
return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
470482
}

examples/server/tests/tests.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#!/bin/bash
22

3+
# make sure we are in the right directory
4+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
5+
cd $SCRIPT_DIR
6+
37
set -eu
48

59
if [ $# -lt 1 ]

examples/server/tests/unit/test_chat_completion.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ def create_server():
1212

1313

1414
@pytest.mark.parametrize(
15-
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
15+
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
1616
[
17-
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
18-
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
17+
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
18+
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
1919
]
2020
)
21-
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
21+
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
2222
global server
2323
server.start()
2424
res = server.make_request("POST", "/chat/completions", data={
@@ -35,10 +35,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
3535
choice = res.body["choices"][0]
3636
assert "assistant" == choice["message"]["role"]
3737
assert match_regex(re_content, choice["message"]["content"])
38-
if truncated:
39-
assert choice["finish_reason"] == "length"
40-
else:
41-
assert choice["finish_reason"] == "stop"
38+
assert choice["finish_reason"] == finish_reason
4239

4340

4441
@pytest.mark.parametrize(
@@ -93,7 +90,7 @@ def test_chat_completion_with_openai_library():
9390
temperature=0.8,
9491
)
9592
print(res)
96-
assert res.choices[0].finish_reason == "stop"
93+
assert res.choices[0].finish_reason == "length"
9794
assert res.choices[0].message.content is not None
9895
assert match_regex("(Suddenly)+", res.choices[0].message.content)
9996

0 commit comments

Comments (0)