Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to

`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.

`requested_fields`: A list of required response fields, for example : `"requested_fields": ["content", "generation_settings/n_predict"]` If there is no field, return an empty json for that field.

**Response format**

- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
Expand Down
6 changes: 5 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<std::string> antiprompt;
std::vector<std::string> requested_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
bool ignore_eos = false;
Expand Down Expand Up @@ -209,6 +210,7 @@ struct server_task {
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());

params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
Expand Down Expand Up @@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {

bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
std::vector<std::string> requested_fields;

slot_params generation_params;

Expand Down Expand Up @@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
}
return res;
return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
}

json to_json_oaicompat_chat() {
Expand Down Expand Up @@ -2063,6 +2066,7 @@ struct server_context {
res->tokens = slot.generated_tokens;
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->requested_fields = slot.params.requested_fields;

res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
Expand Down
34 changes: 34 additions & 0 deletions examples/server/tests/unit/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,40 @@ def check_slots_status():
# assert match_regex(re_content, res.body["content"])


@pytest.mark.parametrize(
"prompt,n_predict,requested_fields",
[
("I believe the meaning of life is", 8, []),
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
],
)
def test_completion_requested_fields(
prompt: str, n_predict: int, requested_fields: list[str]
):
global server
server.start()
res = server.make_request(
"POST",
"/completion",
data={
"n_predict": n_predict,
"prompt": prompt,
"requested_fields": requested_fields,
},
)
assert res.status_code == 200
assert "content" in res.body
assert len(res.body["content"])
if len(requested_fields):
assert res.body["generation_settings/n_predict"] == n_predict
assert res.body["prompt"] == "<s> " + prompt
assert isinstance(res.body["content"], str)
assert len(res.body) == len(requested_fields)
else:
assert len(res.body)
assert "generation_settings" in res.body


def test_n_probs():
global server
server.start()
Expand Down
22 changes: 22 additions & 0 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
return false;
}

// get value by path(key1 / key2)
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
json result = json::object();

for (const std::string & path : paths) {
json current = js;
const auto keys = string_split<std::string>(path, /*separator*/ '/');
bool valid_path = true;
for (const std::string & k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
valid_path = false;
}
}
if (valid_path) {
result[path] = current;
}
}
return result;
}

/**
* this handles 2 cases:
* - only string, example: "string"
Expand Down
Loading