Skip to content

Commit 0a753fb

Browse files
committed
add support for base64
1 parent 30caac3 commit 0a753fb

File tree

4 files changed

+63
-7
lines changed

4 files changed

+63
-7
lines changed

examples/server/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ endforeach()
3434
add_executable(${TARGET} ${TARGET_SRCS})
3535
install(TARGETS ${TARGET} RUNTIME)
3636

37+
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
3738
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
3839

3940
if (LLAMA_SERVER_SSL)

examples/server/server.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3786,6 +3786,17 @@ int main(int argc, char ** argv) {
37863786
return;
37873787
}
37883788

3789+
bool use_base64 = false;
3790+
if (body.count("encoding_format") != 0) {
3791+
const std::string& format = body.at("encoding_format");
3792+
if (format == "base64") {
3793+
use_base64 = true;
3794+
} else if (format != "float") {
3795+
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
3796+
return;
3797+
}
3798+
}
3799+
37893800
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
37903801
for (const auto & tokens : tokenized_prompts) {
37913802
// this check is necessary for models that do not add BOS token to the input
@@ -3837,7 +3848,7 @@ int main(int argc, char ** argv) {
38373848
}
38383849

38393850
// write JSON response
3840-
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
3851+
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
38413852
res_ok(res, root);
38423853
};
38433854

examples/server/tests/unit/test_embedding.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import base64
2+
import struct
13
import pytest
24
from openai import OpenAI
35
from utils import *
@@ -194,3 +196,32 @@ def test_embedding_usage_multiple():
194196
assert res.status_code == 200
195197
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
196198
assert res.body['usage']['prompt_tokens'] == 2 * 9
199+
200+
201+
def test_embedding_openai_library_base64():
    """Check that POST /embeddings honors encoding_format="base64".

    Expects a single result whose "embedding" field is a base64 string
    decoding to a non-empty sequence of 32-bit floats (4 bytes each).
    """
    server.start()
    test_input = "Test base64 embedding output"

    res = server.make_request("POST", "/embeddings", data={
        "input": test_input,
        "encoding_format": "base64"
    })

    assert res.status_code == 200
    assert "data" in res.body
    assert len(res.body["data"]) == 1

    embedding_data = res.body["data"][0]
    assert "embedding" in embedding_data
    assert isinstance(embedding_data["embedding"], str)

    # Only the decode itself goes inside try/except: wrapping the later
    # asserts would swallow real assertion failures and re-report them as
    # "Invalid base64 format". validate=True makes b64decode actually
    # reject non-alphabet characters instead of silently skipping them.
    try:
        decoded = base64.b64decode(embedding_data["embedding"], validate=True)
    except Exception as e:
        pytest.fail(f"Invalid base64 format: {str(e)}")

    # Payload must be a whole number of float32 values, and at least one.
    assert len(decoded) % 4 == 0
    float_count = len(decoded) // 4
    floats = struct.unpack(f'{float_count}f', decoded)
    assert len(floats) > 0

examples/server/utils.hpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,29 @@ static json oaicompat_completion_params_parse(
591591
return llama_params;
592592
}
593593

594-
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
594+
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
595595
json data = json::array();
596596
int32_t n_tokens = 0;
597597
int i = 0;
598598
for (const auto & elem : embeddings) {
599-
data.push_back(json{
600-
{"embedding", json_value(elem, "embedding", json::array())},
601-
{"index", i++},
602-
{"object", "embedding"}
603-
});
599+
if (use_base64) {
600+
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
601+
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
602+
size_t data_size = vec.size() * sizeof(float);
603+
embedding_obj = {
604+
{"embedding", base64::encode(data_ptr, data_size)},
605+
{"index", i++},
606+
{"object", "embedding"},
607+
{"encoding_format", "base64"}
608+
};
609+
} else {
610+
embedding_obj = {
611+
{"embedding", json_value(elem, "embedding", json::array())},
612+
{"index", i++},
613+
{"object", "embedding"}
614+
};
615+
}
616+
data.push_back(embedding_obj);
604617

605618
n_tokens += json_value(elem, "tokens_evaluated", 0);
606619
}

0 commit comments

Comments
 (0)