Commit fb4b9be (parent a43e1dc)

fix model_alias and completion_probabilities

5 files changed: +73 additions, -31 deletions

common/common.h

Lines changed: 1 addition & 1 deletion

@@ -215,7 +215,7 @@ struct common_params {
     struct common_params_speculative speculative;

     std::string model = ""; // model path // NOLINT
-    std::string model_alias = "unknown"; // model alias // NOLINT
+    std::string model_alias = ""; // model alias // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
     std::string hf_token = ""; // HF token // NOLINT
     std::string hf_repo = ""; // HF repo // NOLINT

examples/server/server.cpp

Lines changed: 18 additions & 19 deletions

@@ -250,29 +250,29 @@ struct completion_token_output {
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
+        std::string tok_str;
         float prob;
     };
     std::vector<token_prob> probs;

-    json to_json(const llama_context * ctx) const {
+    json to_json() const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
             probs_for_token.push_back(json {
-                {"tok_str", tok_str},
+                {"tok_str", p.tok_str},
                 {"prob", p.prob},
             });
         }
         return probs_for_token;
     }

-    static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
         json out = json::array();
         for (const auto & prob : probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+            const std::string tok_str = prob.text_to_send;
             out.push_back(json {
                 {"content", tok_str},
-                {"probs", prob.to_json(ctx)},
+                {"probs", prob.to_json()},
             });
         }
         return out;
@@ -309,7 +309,7 @@ struct server_task_result_cmpl_final : server_task_result {

     virtual json to_json() override {
         // non-OAI-compat JSON
-        return json {
+        json res = json {
             {"index", index},
             {"content", content},
             {"id_slot", id_slot},
@@ -326,6 +326,10 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached", n_tokens_cached},
             {"timings", timings.to_json()},
         };
+        if (!probs_output.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        }
+        return res;
     }

     virtual json to_json_oai_compat() override {
@@ -362,12 +366,6 @@ struct server_task_result_cmpl_final : server_task_result {
         if (verbose) {
             res["__verbose"] = to_json();
         }
-
-        // TODO: fix this
-        // if (result.contains("completion_probabilities")) {
-        //     res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-        // }
-
         if (timings.prompt_n >= 0) {
             res.push_back({"timings", timings.to_json()});
         }
@@ -418,6 +416,9 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
+        if (!probs_output.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        }
         if (is_stop) {
             res.push_back({"truncated", truncated});
         }
@@ -2786,9 +2787,11 @@ struct server_context {
        const auto * cur_p = common_sampler_get_candidates(slot.smpl);

        for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+           auto tok_id = cur_p->data[i].id;
            result.probs.push_back({
-               cur_p->data[i].id,
-               i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+               tok_id,
+               tokens_to_output_formatted_string(ctx, tok_id),
+               i >= cur_p->size ? 0.0f : cur_p->data[i].p,
            });
        }

@@ -2920,10 +2923,6 @@ int main(int argc, char ** argv) {
    // struct that contains llama context and inference
    server_context ctx_server;

-   if (params.model_alias == "unknown") {
-       params.model_alias = params.model;
-   }
-
    llama_backend_init();
    llama_numa_init(params.numa);
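For reference, the shape of the `completion_probabilities` field that `probs_vector_to_json()` now emits can be sketched as a plain data structure. This is a minimal illustration implied by the `to_json()` code above; the token strings and probability values below are made up, not taken from the commit.

```python
# Sketch of a /completion response carrying the new per-token probabilities.
# Field names follow to_json()/probs_vector_to_json() above; the token strings
# and probabilities are illustrative placeholders only.
response_excerpt = {
    "content": " to be happy",
    "completion_probabilities": [
        {
            "content": " to",                      # text_to_send of the sampled token
            "probs": [
                {"tok_str": " to", "prob": 0.62},  # top candidate: token_prob.tok_str + token_prob.prob
                {"tok_str": " not", "prob": 0.11}, # next candidate, up to n_probs entries
            ],
        },
        # ... one entry per generated token ...
    ],
}
```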

examples/server/tests/README.md

Lines changed: 6 additions & 0 deletions

@@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for debugging):
 DEBUG=1 ./tests.sh -s -v -x
 ```

+Hint: You can compile and run the tests in a single command, which is useful for local development:
+
+```shell
+cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
+```
+
 To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)

examples/server/tests/unit/test_chat_completion.py

Lines changed: 9 additions & 11 deletions

@@ -14,7 +14,7 @@ def create_server():
 @pytest.mark.parametrize(
     "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
         ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
         ],
     })
     assert res.status_code == 200
+    assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
@@ -39,17 +40,17 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):


 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
     [
-        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
     ]
 )
-def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
     global server
+    server.model_alias = None
     server.start()
     res = server.make_stream_request("POST", "/chat/completions", data={
-        "model": model,
         "max_tokens": max_tokens,
         "messages": [
             {"role": "system", "content": system_prompt},
@@ -60,16 +61,13 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
     content = ""
     for data in res:
         choice = data["choices"][0]
+        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
             assert "content" not in choice["delta"]
             assert match_regex(re_content, content)
-            # FIXME: not sure why this is incorrect in stream mode
-            # if truncated:
-            #     assert choice["finish_reason"] == "length"
-            # else:
-            #     assert choice["finish_reason"] == "stop"
+            assert choice["finish_reason"] == finish_reason
         else:
             assert choice["finish_reason"] is None
             content += choice["delta"]["content"]

examples/server/tests/unit/test_completion.py

Lines changed: 39 additions & 0 deletions

@@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
         content += data["content"]


+def test_completion_stream_vs_non_stream():
+    global server
+    server.start()
+    res_stream = server.make_stream_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "I believe the meaning of life is",
+        "stream": True,
+    })
+    res_non_stream = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "I believe the meaning of life is",
+    })
+    content_stream = ""
+    for data in res_stream:
+        content_stream += data["content"]
+    assert content_stream == res_non_stream.body["content"]
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
@@ -221,3 +239,24 @@ def check_slots_status():
         assert len(res.body["content"]) > 10
         # FIXME: the result is not deterministic when using other slot than slot 0
         # assert match_regex(re_content, res.body["content"])
+
+
+def test_n_probs():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "probs" in tok
+        assert len(tok["probs"]) == 10
+        for prob in tok["probs"]:
+            assert "prob" in prob
+            assert "tok_str" in prob
+            assert 0.0 <= prob["prob"] <= 1.0
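Outside the pytest harness, the same check as `test_n_probs` can be reproduced by hand. Below is a minimal sketch using the `requests` library; it assumes a `llama-server` instance is already running and listening on `localhost:8080` (adjust the URL and parameters for your setup).

```python
import requests

# Ask the server for 5 tokens and the top-10 candidate probabilities per token,
# mirroring the parameters used in test_n_probs above.
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_probs": 10,
    "temperature": 0.0,
    "n_predict": 5,
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # Each generated token carries its candidate strings and probabilities.
    print(tok["content"], [(p["tok_str"], round(p["prob"], 3)) for p in tok["probs"]])
```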
