Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ endforeach()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)

target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})

if (LLAMA_SERVER_SSL)
Expand Down
13 changes: 12 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3786,6 +3786,17 @@ int main(int argc, char ** argv) {
return;
}

bool use_base64 = false;
if (body.count("encoding_format") != 0) {
const std::string& format = body.at("encoding_format");
if (format == "base64") {
use_base64 = true;
} else if (format != "float") {
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
return;
}
}

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
Expand Down Expand Up @@ -3837,7 +3848,7 @@ int main(int argc, char ** argv) {
}

// write JSON response
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
res_ok(res, root);
};

Expand Down
31 changes: 31 additions & 0 deletions examples/server/tests/unit/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import struct
import pytest
from openai import OpenAI
from utils import *
Expand Down Expand Up @@ -194,3 +196,32 @@ def test_embedding_usage_multiple():
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 9


def test_embedding_openai_library_base64():
server.start()
test_input = "Test base64 embedding output"

res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input,
"encoding_format": "base64"
})

assert res.status_code == 200
assert "data" in res.body
assert len(res.body["data"]) == 1

embedding_data = res.body["data"][0]
assert "embedding" in embedding_data
assert isinstance(embedding_data["embedding"], str)

# Verify embedding is valid base64
try:
decoded = base64.b64decode(embedding_data["embedding"])
# Verify decoded data can be converted back to float array
float_count = len(decoded) // 4 # 4 bytes per float
floats = struct.unpack(f'{float_count}f', decoded)
assert len(floats) > 0
assert all(isinstance(x, float) for x in floats)
except Exception as e:
pytest.fail(f"Invalid base64 format: {str(e)}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bad practice to use try..catch inside a test, because if one of the assert fails, it won't let you know it happens at exactly which line of code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngxson, thanks for the hint. my bad, I forgot to remove it.

28 changes: 22 additions & 6 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "common/base64.hpp"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
Expand Down Expand Up @@ -591,16 +592,31 @@ static json oaicompat_completion_params_parse(
return llama_params;
}

static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & elem : embeddings) {
data.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
json embedding_obj;

if (use_base64) {
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
size_t data_size = vec.size() * sizeof(float);
embedding_obj = {
{"embedding", base64::encode(data_ptr, data_size)},
{"index", i++},
{"object", "embedding"},
{"encoding_format", "base64"}
};
} else {
embedding_obj = {
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
}
data.push_back(embedding_obj);

n_tokens += json_value(elem, "tokens_evaluated", 0);
}
Expand Down
Loading