From 2e04ccf4e66a56eade51c2b62d7fe9026021fbb9 Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Wed, 18 Dec 2024 01:21:44 +0300
Subject: [PATCH 1/6] llama_server_response_fields

---
 examples/server/server.cpp |  6 +++++-
 examples/server/utils.hpp  | 27 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 436170a034fde..bc179cfb5effd 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -91,6 +91,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+    std::vector<std::string> requested_fields;
     bool timings_per_token = false;
     bool ignore_eos = false;
 
@@ -205,6 +206,7 @@ struct server_task {
         params.n_discard = json_value(data, "n_discard", defaults.n_discard);
         //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+        params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());
 
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -482,6 +484,7 @@ struct server_task_result_cmpl_final : server_task_result {
     stop_type stop = STOP_TYPE_NONE;
 
     std::vector<completion_token_output> probs_output;
+    std::vector<std::string> requested_fields;
 
     slot_params generation_params;
 
@@ -527,7 +530,7 @@ struct server_task_result_cmpl_final : server_task_result {
         if (!probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
-        return res;
+        return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -1960,6 +1963,7 @@ struct server_context {
             res->content = slot.generated_text;
             res->timings = slot.get_timings();
             res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+            res->requested_fields = slot.params.requested_fields;
 
             res->truncated = slot.truncated;
             res->n_decoded = slot.n_decoded;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8fffe484aec12..0ac8b2cce8478 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -88,6 +88,33 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// get value by path(key1 / key2)
+static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
+    json result = json::object();
+
+    for (const std::string& path : paths) {
+        json current = js;
+        std::istringstream stream(path);
+        std::string key;
+        std::vector<std::string> keys;
+        while (std::getline(stream, key, '/')) {
+            keys.push_back(key);
+        }
+        bool valid_path = true;
+        for (const std::string& k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"
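The helper introduced above walks each "/"-separated path and, when every key resolves, copies the value into a flat result object keyed by the full path string; paths that fail to resolve are skipped silently. A minimal Python sketch of the same semantics, using plain dicts in place of nlohmann::json (get_nested_values is an illustrative re-implementation, not part of the server code):

def get_nested_values(paths: list[str], js: dict) -> dict:
    # Mirrors json_get_nested_values: resolve each "/"-separated path,
    # keep hits under the flattened path key, drop misses silently.
    result = {}
    for path in paths:
        current = js
        valid_path = True
        for key in path.split("/"):
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                valid_path = False
                break
        if valid_path:
            result[path] = current
    return result

response = {"content": " 42.", "generation_settings": {"n_predict": 8}}
print(get_nested_values(["content", "generation_settings/n_predict", "missing/key"], response))
# -> {'content': ' 42.', 'generation_settings/n_predict': 8}; the invalid path is dropped
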
From bc09b1acdf18ca199489938abc93bfc934552019 Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Sun, 22 Dec 2024 18:57:55 +0300
Subject: [PATCH 2/6] llama_server_response_fields_fix_issues

---
 examples/server/README.md                     |  2 ++
 examples/server/tests/unit/test_completion.py | 36 +++++++++++++++++++
 examples/server/utils.hpp                     | 15 +++-----
 3 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index 63a7bf43a920d..ccb40dba39ede 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -442,6 +442,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
+`requested_fields`: A list of required response fields, for example : `"requested_fields": ["content", "generation_settings/n_predict"]` If there is no field, return an empty json for that field.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 062ebcd4a05cc..83d1a5d779635 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -249,6 +249,42 @@ def check_slots_status():
     # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "prompt,n_predict,requested_fields",
+    [
+        ("I believe the meaning of life is", 8, []),
+        (
+            "I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"],
+        ),
+    ],
+)
+def test_completion_requested_fields(
+    prompt: str, n_predict: int, requested_fields: list[str]
+):
+    global server
+    server.start()
+    res = server.make_request(
+        "POST",
+        "/completion",
+        data={
+            "n_predict": n_predict,
+            "prompt": prompt,
+            "requested_fields": requested_fields,
+        },
+    )
+    assert res.status_code == 200
+    assert "content" in res.body
+    assert len(res.body["content"])
+    if len(requested_fields) > 0:
+        assert res.body["generation_settings/n_predict"] == n_predict
+        assert res.body["prompt"] == " " + prompt
+        assert isinstance(res.body["content"], str)
+        assert len(res.body) == len(requested_fields)
+    else:
+        assert len(res.body) > 0
+        assert "generation_settings" in res.body
+
+
 def test_n_probs():
     global server
     server.start()
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 0ac8b2cce8478..9ad9000672919 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -89,19 +89,14 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
 }
 
 // get value by path(key1 / key2)
-static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
     json result = json::object();
-
-    for (const std::string& path : paths) {
+
+    for (const std::string & path : paths) {
         json current = js;
-        std::istringstream stream(path);
-        std::string key;
-        std::vector<std::string> keys;
-        while (std::getline(stream, key, '/')) {
-            keys.push_back(key);
-        }
+        const auto keys = string_split<std::string>(path, /*delim*/ '/');
         bool valid_path = true;
-        for (const std::string& k : keys) {
+        for (const std::string & k : keys) {
             if (valid_path && current.is_object() && current.contains(k)) {
                 current = current[k];
             } else {
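For reference, the filtered body that the new test asserts against contains exactly one entry per requested field, with nested paths flattened into single "a/b" keys. A sketch of the expected shape (the content value is hypothetical, and the leading space in "prompt" is added during tokenization):

# Hypothetical filtered /completion body when
# requested_fields = ["content", "generation_settings/n_predict", "prompt"]:
filtered_body = {
    "content": "...generated text...",              # always present as a string
    "generation_settings/n_predict": 32,            # nested path, flattened key
    "prompt": " I believe the meaning of life is",  # note the leading space
}
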
From 0958ee96ac464f80d22d59bcd0b3593a0a2149be Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Sun, 22 Dec 2024 19:16:28 +0300
Subject: [PATCH 3/6] params fixes

---
 examples/server/tests/unit/test_completion.py | 4 ++--
 examples/server/utils.hpp                     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 1a6c7797429be..ee65901f15cdb 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -283,13 +283,13 @@ def test_completion_requested_fields(
     assert res.status_code == 200
     assert "content" in res.body
     assert len(res.body["content"])
-    if len(requested_fields) > 0:
+    if len(requested_fields):
         assert res.body["generation_settings/n_predict"] == n_predict
         assert res.body["prompt"] == " " + prompt
         assert isinstance(res.body["content"], str)
         assert len(res.body) == len(requested_fields)
     else:
-        assert len(res.body) > 0
+        assert len(res.body)
         assert "generation_settings" in res.body
 
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index e5164a8895a52..d0e8d5266ae1e 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -94,7 +94,7 @@ static json json_get_nested_values(const std::vector<std::string> & paths, cons
 
     for (const std::string & path : paths) {
         json current = js;
-        const auto keys = string_split<std::string>(path, /*delim*/ '/');
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
         bool valid_path = true;
         for (const std::string & k : keys) {
             if (valid_path && current.is_object() && current.contains(k)) {

From 3d3c6bae46417cdd572c6b3f2a3e132dc004ca31 Mon Sep 17 00:00:00 2001
From: nvrxq
Date: Sun, 22 Dec 2024 19:18:54 +0300
Subject: [PATCH 4/6] fix

---
 examples/server/tests/unit/test_completion.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index ee65901f15cdb..f7a427c337c8b 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -261,9 +261,7 @@ def check_slots_status():
     "prompt,n_predict,requested_fields",
     [
         ("I believe the meaning of life is", 8, []),
-        (
-            "I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"],
-        ),
+        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
     ],
 )
 def test_completion_requested_fields(

From 4cf1fef3207d57b2ab39a6956b0acf77777e4bb5 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 24 Dec 2024 16:26:46 +0100
Subject: [PATCH 5/6] clarify docs

---
 examples/server/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index a1830c098e6bf..033a514ede149 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -450,7 +450,7 @@ These words will not be included in the completion, so make sure to add them to
 
 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
 
-`requested_fields`: A list of required response fields, for example : `"requested_fields": ["content", "generation_settings/n_predict"]` If there is no field, return an empty json for that field.
+`requested_fields`: A list of response fields, for example: `"requested_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
 
 **Response format**
 
From b8679c0bb5d37952163458ae699fb931de54d959 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 24 Dec 2024 16:28:44 +0100
Subject: [PATCH 6/6] change to "response_fields"

---
 examples/server/README.md                     |  2 +-
 examples/server/server.cpp                    | 10 +++++-----
 examples/server/tests/unit/test_completion.py | 12 ++++++------
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index 033a514ede149..958d1cdd14fb8 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -450,7 +450,7 @@ These words will not be included in the completion, so make sure to add them to
 
 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
 
-`requested_fields`: A list of response fields, for example: `"requested_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
+`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
 
 **Response format**
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7b277b9dcc6da..4affc7cde7816 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -92,7 +92,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
-    std::vector<std::string> requested_fields;
+    std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
     bool ignore_eos = false;
@@ -210,7 +210,7 @@ struct server_task {
         params.n_discard = json_value(data, "n_discard", defaults.n_discard);
         //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
-        params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());
+        params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
 
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -524,7 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
-    std::vector<std::string> requested_fields;
+    std::vector<std::string> response_fields;
 
     slot_params generation_params;
@@ -571,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
-        return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
+        return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -2066,7 +2066,7 @@ struct server_context {
             res->tokens = slot.generated_tokens;
             res->timings = slot.get_timings();
             res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
-            res->requested_fields = slot.params.requested_fields;
+            res->response_fields = slot.params.response_fields;
 
             res->truncated = slot.truncated;
             res->n_decoded = slot.n_decoded;
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index f7a427c337c8b..00d5ce391d8f0 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -258,14 +258,14 @@ def check_slots_status():
 
 
 @pytest.mark.parametrize(
-    "prompt,n_predict,requested_fields",
+    "prompt,n_predict,response_fields",
     [
         ("I believe the meaning of life is", 8, []),
         ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
     ],
 )
-def test_completion_requested_fields(
-    prompt: str, n_predict: int, requested_fields: list[str]
+def test_completion_response_fields(
+    prompt: str, n_predict: int, response_fields: list[str]
 ):
     global server
     server.start()
@@ -275,17 +275,17 @@ def test_completion_requested_fields(
         data={
             "n_predict": n_predict,
             "prompt": prompt,
-            "requested_fields": requested_fields,
+            "response_fields": response_fields,
         },
     )
     assert res.status_code == 200
     assert "content" in res.body
     assert len(res.body["content"])
-    if len(requested_fields):
+    if len(response_fields):
         assert res.body["generation_settings/n_predict"] == n_predict
         assert res.body["prompt"] == " " + prompt
         assert isinstance(res.body["content"], str)
-        assert len(res.body) == len(requested_fields)
+        assert len(res.body) == len(response_fields)
     else:
         assert len(res.body)
         assert "generation_settings" in res.body
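
With the parameter settled as response_fields, client-side usage might look like the following sketch. It assumes a llama-server instance already listening on localhost:8080 (the default port) and uses the third-party requests package; both are assumptions, not part of this patch series:

import requests  # pip install requests

# Ask the server to return only two fields of the /completion response;
# the server address and port below are assumed defaults.
res = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        "response_fields": ["content", "generation_settings/n_predict"],
    },
)
body = res.json()
# Only the requested keys are returned; the nested path comes back flattened.
print(body["content"])
print(body["generation_settings/n_predict"])  # -> 8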