
Commit ce5aed2

Merge branch 'concedo_experimental' into croco_nex_0
2 parents f4dc3a3 + 5eb314a commit ce5aed2

23 files changed, +25,913 -13,372 lines changed

examples/server/server.cpp

Lines changed: 17 additions & 2 deletions
@@ -92,6 +92,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+    std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
     bool ignore_eos = false;
@@ -209,6 +210,7 @@ struct server_task {
         params.n_discard = json_value(data, "n_discard", defaults.n_discard);
         //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+        params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
 
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<std::string> response_fields;
 
     slot_params generation_params;
 
@@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
         if (!stream && !probs_output.empty()) {
            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
-        return res;
+        return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -2066,6 +2069,7 @@ struct server_context {
         res->tokens = slot.generated_tokens;
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+        res->response_fields = slot.params.response_fields;
 
         res->truncated = slot.truncated;
         res->n_decoded = slot.n_decoded;
@@ -3786,6 +3790,17 @@ int main(int argc, char ** argv) {
             return;
         }
 
+        bool use_base64 = false;
+        if (body.count("encoding_format") != 0) {
+            const std::string& format = body.at("encoding_format");
+            if (format == "base64") {
+                use_base64 = true;
+            } else if (format != "float") {
+                res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
@@ -3837,7 +3852,7 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
         res_ok(res, root);
     };

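As a quick, hedged illustration of the new `response_fields` request parameter (not part of the diff itself): a minimal Python sketch, assuming a llama.cpp server listening on http://localhost:8080 and the `requests` package. Fields are named by slash-separated paths, and the server returns only those fields, keyed by the full path string.

    # Sketch only: assumes a local server on port 8080 and the `requests` package.
    import requests

    res = requests.post(
        "http://localhost:8080/completion",
        json={
            "prompt": "I believe the meaning of life is",
            "n_predict": 8,
            # return only these fields; nested values use slash-separated paths
            "response_fields": ["content", "generation_settings/n_predict"],
        },
    )
    res.raise_for_status()
    body = res.json()
    # Expected shape given the filter above, with each path string as a flat key:
    # {"content": "...", "generation_settings/n_predict": 8}
    print(body)
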
examples/server/tests/unit/test_completion.py

Lines changed: 38 additions & 3 deletions
@@ -95,7 +95,7 @@ def test_consistent_result_same_seed(n_slots: int):
         res = server.make_request("POST", "/completion", data={
             "prompt": "I believe the meaning of life is",
             "seed": 42,
-            "temperature": 1.0,
+            "temperature": 0.0,
             "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
         })
         if last_res is not None:
@@ -120,9 +120,10 @@ def test_different_result_different_seed(n_slots: int):
         assert res.body["content"] != last_res.body["content"]
         last_res = res
 
-
+# TODO figure why it don't work with temperature = 1
+# @pytest.mark.parametrize("temperature", [0.0, 1.0])
 @pytest.mark.parametrize("n_batch", [16, 32])
-@pytest.mark.parametrize("temperature", [0.0, 1.0])
+@pytest.mark.parametrize("temperature", [0.0])
 def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
     global server
     server.n_batch = n_batch
@@ -257,6 +258,40 @@ def check_slots_status():
     # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "prompt,n_predict,response_fields",
+    [
+        ("I believe the meaning of life is", 8, []),
+        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
+    ],
+)
+def test_completion_response_fields(
+    prompt: str, n_predict: int, response_fields: list[str]
+):
+    global server
+    server.start()
+    res = server.make_request(
+        "POST",
+        "/completion",
+        data={
+            "n_predict": n_predict,
+            "prompt": prompt,
+            "response_fields": response_fields,
+        },
+    )
+    assert res.status_code == 200
+    assert "content" in res.body
+    assert len(res.body["content"])
+    if len(response_fields):
+        assert res.body["generation_settings/n_predict"] == n_predict
+        assert res.body["prompt"] == "<s> " + prompt
+        assert isinstance(res.body["content"], str)
+        assert len(res.body) == len(response_fields)
+    else:
+        assert len(res.body)
+        assert "generation_settings" in res.body
+
+
 def test_n_probs():
     global server
     server.start()

examples/server/tests/unit/test_embedding.py

Lines changed: 41 additions & 0 deletions
@@ -1,3 +1,5 @@
+import base64
+import struct
 import pytest
 from openai import OpenAI
 from utils import *
@@ -194,3 +196,42 @@ def test_embedding_usage_multiple():
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
     assert res.body['usage']['prompt_tokens'] == 2 * 9
+
+
+def test_embedding_openai_library_base64():
+    server.start()
+    test_input = "Test base64 embedding output"
+
+    # get embedding in default format
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": test_input
+    })
+    assert res.status_code == 200
+    vec0 = res.body["data"][0]["embedding"]
+
+    # get embedding in base64 format
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": test_input,
+        "encoding_format": "base64"
+    })
+
+    assert res.status_code == 200
+    assert "data" in res.body
+    assert len(res.body["data"]) == 1
+
+    embedding_data = res.body["data"][0]
+    assert "embedding" in embedding_data
+    assert isinstance(embedding_data["embedding"], str)
+
+    # Verify embedding is valid base64
+    decoded = base64.b64decode(embedding_data["embedding"])
+    # Verify decoded data can be converted back to float array
+    float_count = len(decoded) // 4  # 4 bytes per float
+    floats = struct.unpack(f'{float_count}f', decoded)
+    assert len(floats) > 0
+    assert all(isinstance(x, float) for x in floats)
+    assert len(floats) == len(vec0)
+
+    # make sure the decoded data is the same as the original
+    for x, y in zip(floats, vec0):
+        assert abs(x - y) < EPSILON

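The test above drives the endpoint with raw HTTP; as a complementary, hedged sketch, here is the same round trip through the OpenAI Python client. Assumptions (not confirmed by the diff): a local server exposing /v1 on port 8080, a placeholder model name, and that the client passes the embedding through as a base64 string when `encoding_format` is set explicitly.

    # Sketch only: assumes a local server at http://localhost:8080/v1 and openai-python 1.x.
    import base64
    import struct
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key")  # placeholder key
    res = client.embeddings.create(
        model="placeholder-model",  # hypothetical name; the loaded model serves the request
        input="Test base64 embedding output",
        encoding_format="base64",
    )
    raw = base64.b64decode(res.data[0].embedding)  # embedding assumed to arrive as a base64 string
    floats = struct.unpack(f"{len(raw) // 4}f", raw)  # 4 bytes per float32
    print(len(floats), "dimensions")
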
examples/server/utils.hpp

Lines changed: 44 additions & 6 deletions
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "common/base64.hpp"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -90,6 +91,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// get value by path(key1 / key2)
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+    json result = json::object();
+
+    for (const std::string & path : paths) {
+        json current = js;
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
+        bool valid_path = true;
+        for (const std::string & k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"
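For orientation, a non-authoritative Python rendering of what the `json_get_nested_values` helper above does; the actual implementation is the C++ in the hunk, and this sketch only mirrors its logic.

    # Illustrative Python equivalent of the C++ helper above (not part of the commit).
    def json_get_nested_values(paths, js):
        result = {}
        for path in paths:
            current = js
            valid_path = True
            for key in path.split("/"):
                if valid_path and isinstance(current, dict) and key in current:
                    current = current[key]
                else:
                    valid_path = False
            if valid_path:
                result[path] = current  # the full slash path becomes a flat key
        return result

    # json_get_nested_values(["generation_settings/n_predict"],
    #                        {"generation_settings": {"n_predict": 8}})
    # -> {"generation_settings/n_predict": 8}
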
@@ -591,16 +614,31 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
-        data.push_back(json{
-            {"embedding", json_value(elem, "embedding", json::array())},
-            {"index", i++},
-            {"object", "embedding"}
-        });
+        json embedding_obj;
+
+        if (use_base64) {
+            const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+            const char* data_ptr = reinterpret_cast<const char*>(vec.data());
+            size_t data_size = vec.size() * sizeof(float);
+            embedding_obj = {
+                {"embedding", base64::encode(data_ptr, data_size)},
+                {"index", i++},
+                {"object", "embedding"},
+                {"encoding_format", "base64"}
+            };
+        } else {
+            embedding_obj = {
+                {"embedding", json_value(elem, "embedding", json::array())},
+                {"index", i++},
+                {"object", "embedding"}
+            };
+        }
+        data.push_back(embedding_obj);
 
         n_tokens += json_value(elem, "tokens_evaluated", 0);
     }
