4 changes: 3 additions & 1 deletion common/common.cpp
@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
-                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
}
sum /= 32760.0; // make an int16 range
break;
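For reference, the call sites updated below pass `embd_norm = 2`, which selects Euclidean (L2) normalization. Here is a minimal, self-contained sketch of that branch; it is illustrative only, not the actual `common_embd_normalize` implementation:

```cpp
// Sketch of the embd_norm == 2 (Euclidean) path; illustrative, not the library code.
#include <cmath>
#include <cstdio>
#include <vector>

static void embd_normalize_l2(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * inp[i]; // accumulate squared components
    }
    const float norm = sum > 0.0 ? (float) (1.0 / std::sqrt(sum)) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * norm; // result has unit Euclidean length
    }
}

int main() {
    std::vector<float> in = { 3.0f, 4.0f }, out(2);
    embd_normalize_l2(in.data(), out.data(), 2);
    std::printf("%f %f\n", out[0], out[1]); // prints 0.600000 0.800000
}
```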
3 changes: 2 additions & 1 deletion common/common.h
@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
// Embedding utils
//

-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);

float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

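One possible shape for the `TODO` above, sketched as a hypothetical enum (the names are invented and not part of this PR; the values mirror the integers that `common_embd_normalize` already accepts, with `-1` meaning no normalization, `0` max-absolute, and `2` Euclidean):

```cpp
// Hypothetical sketch for the TODO; not part of this change.
enum common_embd_norm : int {
    COMMON_EMBD_NORM_NONE      = -1, // no normalization
    COMMON_EMBD_NORM_MAX_ABS   =  0, // scale by max |x| into an int16 range
    COMMON_EMBD_NORM_EUCLIDEAN =  2, // p = 2, as used by the server endpoints
};

void common_embd_normalize(const float * inp, float * out, int n, common_embd_norm embd_norm);
```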
2 changes: 1 addition & 1 deletion examples/gritlm/gritlm.cpp
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}

std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
result.push_back(emb_norm);

#ifdef GRIT_DEBUG
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
}
}

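Both call sites above normalize with `embd_norm = 2`, which pays off downstream: for unit-length vectors, cosine similarity (cf. `common_embd_similarity_cos` in `common/common.h`) reduces to a plain dot product. A standalone illustration, not the library code:

```cpp
// For unit-length vectors, cosine similarity equals the dot product.
#include <cstdio>

static float dot(const float * a, const float * b, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; i++) {
        s += a[i] * b[i];
    }
    return s;
}

int main() {
    const float a[2] = { 0.6f, 0.8f }; // already L2-normalized
    const float b[2] = { 0.8f, 0.6f }; // already L2-normalized
    std::printf("cos = %f\n", dot(a, b, 2)); // prints cos = 0.960000
}
```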
42 changes: 42 additions & 0 deletions examples/server/README.md
@@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \

### POST `/v1/embeddings`: OpenAI-compatible embeddings API

This endpoint requires that the model uses a pooling type different from `none`. The embeddings are normalized using the Euclidean norm.

*Options:*

See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
}'
```

### POST `/embeddings`: non-OpenAI-compatible embeddings API

This endpoint supports all pooling types, including `--pooling none`. When the pooling is `none`, the response contains the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embedding is returned, normalized using the Euclidean norm.

Note that the response format of this endpoint is different from `/v1/embeddings`.

*Options:*

Same as the `/v1/embeddings` endpoint.

*Examples:*

Same as the `/v1/embeddings` endpoint.
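
For instance, a minimal request could look like the following (hypothetical model path; default host and port assumed):

```shell
# start the server so that per-token embeddings are returned
llama-server -m model.gguf --embeddings --pooling none

# request embeddings for a single prompt
curl http://localhost:8080/embeddings \
    -H "Content-Type: application/json" \
    -d '{"input": "Hello, world"}'
```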

**Response format**

```json
[
  {
    "index": 0,
    "embedding": [
      [ ... embeddings for token 0 ... ],
      [ ... embeddings for token 1 ... ],
      [ ... ],
      [ ... embeddings for token N-1 ... ]
    ]
  },
  ...
  {
    "index": P,
    "embedding": [
      [ ... embeddings for token 0 ... ],
      [ ... embeddings for token 1 ... ],
      [ ... ],
      [ ... embeddings for token N-1 ... ]
    ]
  }
]
```

### GET `/slots`: Returns the current slots processing state

> [!WARNING]
75 changes: 57 additions & 18 deletions examples/server/server.cpp
@@ -704,6 +704,7 @@ struct server_task_result_cmpl_partial : server_task_result {
{"delta",
json {
{"content", content},
{"tokens", tokens}
}},
}});
}
@@ -726,18 +727,32 @@

struct server_task_result_embd : server_task_result {
int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;

int32_t n_tokens;

+    // OAI-compat fields
+    bool oaicompat = false;
+
virtual int get_index() override {
return index;
}

virtual json to_json() override {
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
+        return json {
+            {"index",     index},
+            {"embedding", embedding},
+        };
+    }
+
+    json to_json_oaicompat() {
return json {
{"index", index},
{"embedding", embedding},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
@@ -2017,9 +2032,10 @@ struct server_context {

void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
-        res->id       = slot.id_task;
-        res->index    = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->n_tokens  = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;

const int n_embd = llama_n_embd(model);

@@ -2038,12 +2054,18 @@
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}

-            common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
}

SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2657,7 +2679,10 @@

// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+            // without pooling, we want to output the embeddings for all the tokens in the batch
+            const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3665,14 +3690,17 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};

-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
const json body = json::parse(req.body);
-        bool oaicompat = false;

+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // for the shape of input/content, see tokenize_input_prompts()
json prompt;
if (body.contains("input")) {
oaicompat = true;
if (body.count("input") != 0) {
prompt = body.at("input");
} else if (body.contains("content")) {
oaicompat = false;
@@ -3697,10 +3725,15 @@
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);

task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);

+                // OAI-compat
+                task.params.oaicompat = oaicompat;
+
tasks.push_back(task);
}

@@ -3728,12 +3761,18 @@
}

// write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
res_ok(res, root);
};

+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3907,7 +3946,7 @@ int main(int argc, char ** argv) {
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings_oai);
svr->Post("/rerank", handle_rerank);
svr->Post("/reranking", handle_rerank);
svr->Post("/v1/rerank", handle_rerank);