diff --git a/tools/server/server.cpp b/tools/server/server.cpp index de6e1a322b2c2..3948637f7e47c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -5080,15 +5080,6 @@ int main(int argc, char ** argv) { const json body = json::parse(req.body); - // TODO: implement - //int top_n = 1; - //if (body.count("top_n") != 1) { - // top_n = body.at("top_n"); - //} else { - // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - // return; - //} - // if true, use TEI API format, otherwise use Jina API format // Jina: https://jina.ai/reranker/ // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank @@ -5113,6 +5104,8 @@ int main(int argc, char ** argv) { return; } + int top_n = json_value(body, "top_n", (int)documents.size()); + // create and queue the task json responses = json::array(); bool error = false; @@ -5153,7 +5146,8 @@ int main(int argc, char ** argv) { body, responses, is_tei_format, - documents); + documents, + top_n); res_ok(res, root); }; diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 4ca1423aaf2d4..4fda0410e340e 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -849,47 +849,44 @@ static json format_response_rerank( const json & request, const json & ranks, bool is_tei_format, - std::vector & texts) { - json res; - if (is_tei_format) { - // TEI response format - res = json::array(); - bool return_text = json_value(request, "return_text", false); - for (const auto & rank : ranks) { - int index = json_value(rank, "index", 0); - json elem = json{ - {"index", index}, - {"score", json_value(rank, "score", 0.0)}, - }; - if (return_text) { - elem["text"] = std::move(texts[index]); - } - res.push_back(elem); - } - } else { - // Jina response format - json results = json::array(); - int32_t n_tokens = 0; - for (const auto & rank : ranks) { - results.push_back(json{ - {"index", json_value(rank, "index", 0)}, - {"relevance_score", json_value(rank, "score", 0.0)}, - }); - - n_tokens += json_value(rank, "tokens_evaluated", 0); - } - - res = json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", json{ - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"results", results} + std::vector & texts, + int top_n) { + int32_t n_tokens = 0; + bool return_text = is_tei_format && json_value(request, "return_text", false); + std::vector elements; // Temporary vector to hold unsorted elements + std::string score_label = is_tei_format ? "score" : "relevance_score"; + for (const auto & rank : ranks) { + int index = json_value(rank, "index", 0); + json elem = json{ + {"index", index}, + {score_label, json_value(rank, "score", 0.0)}, }; + n_tokens += json_value(rank, "tokens_evaluated", 0); + if (return_text) { + elem["text"] = std::move(texts[index]); + } + elements.push_back(elem); } + std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) { + return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0); + }); + + elements.resize(std::min(top_n, (int)elements.size())); + json results = elements; + + if (is_tei_format) return results; + + json res = json{ + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json{ + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", results} + }; + return res; }