Skip to content

Commit 374b648

Browse files
committed
server / ranking : add sorting and management of top_n
1 parent 5113efd commit 374b648

File tree

2 files changed

+47
-45
lines changed

2 files changed

+47
-45
lines changed

tools/server/server.cpp

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5061,13 +5061,13 @@ int main(int argc, char ** argv) {
50615061
const json body = json::parse(req.body);
50625062

50635063
// TODO: implement
5064-
//int top_n = 1;
5065-
//if (body.count("top_n") != 1) {
5066-
// top_n = body.at("top_n");
5067-
//} else {
5068-
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
5069-
// return;
5070-
//}
5064+
int top_n = 1;
5065+
if (body.count("top_n") == 1) {
5066+
top_n = body.at("top_n");
5067+
} else {
5068+
res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
5069+
return;
5070+
}
50715071

50725072
// if true, use TEI API format, otherwise use Jina API format
50735073
// Jina: https://jina.ai/reranker/
@@ -5133,7 +5133,8 @@ int main(int argc, char ** argv) {
51335133
body,
51345134
responses,
51355135
is_tei_format,
5136-
documents);
5136+
documents,
5137+
top_n);
51375138

51385139
res_ok(res, root);
51395140
};

tools/server/utils.hpp

Lines changed: 38 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -849,46 +849,47 @@ static json format_response_rerank(
849849
const json & request,
850850
const json & ranks,
851851
bool is_tei_format,
852-
std::vector<std::string> & texts) {
853-
json res;
854-
if (is_tei_format) {
855-
// TEI response format
856-
res = json::array();
857-
bool return_text = json_value(request, "return_text", false);
858-
for (const auto & rank : ranks) {
859-
int index = json_value(rank, "index", 0);
860-
json elem = json{
861-
{"index", index},
862-
{"score", json_value(rank, "score", 0.0)},
863-
};
864-
if (return_text) {
865-
elem["text"] = std::move(texts[index]);
866-
}
867-
res.push_back(elem);
868-
}
869-
} else {
870-
// Jina response format
871-
json results = json::array();
872-
int32_t n_tokens = 0;
873-
for (const auto & rank : ranks) {
874-
results.push_back(json{
875-
{"index", json_value(rank, "index", 0)},
876-
{"relevance_score", json_value(rank, "score", 0.0)},
877-
});
878-
879-
n_tokens += json_value(rank, "tokens_evaluated", 0);
852+
std::vector<std::string> & texts,
853+
int top_n) {
854+
json results;
855+
int32_t n_tokens = 0;
856+
bool return_text = is_tei_format && json_value(request, "return_text", false);
857+
std::vector<json> elements; // Temporary vector to hold unsorted elements
858+
std::string score_label = is_tei_format ? "score" : "relevance_score";
859+
for (const auto & rank : ranks) {
860+
int index = json_value(rank, "index", 0);
861+
json elem = json{
862+
{"index", index},
863+
{score_label, json_value(rank, "score", 0.0)},
864+
};
865+
n_tokens += json_value(rank, "tokens_evaluated", 0);
866+
if (return_text) {
867+
elem["text"] = std::move(texts[index]);
880868
}
869+
elements.push_back(elem);
870+
}
871+
872+
std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
873+
return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
874+
});
881875

882-
res = json{
883-
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
884-
{"object", "list"},
885-
{"usage", json{
886-
{"prompt_tokens", n_tokens},
887-
{"total_tokens", n_tokens}
888-
}},
889-
{"results", results}
890-
};
876+
results = json::array();
877+
int count = 0;
878+
for (const auto & elem : elements) {
879+
if (++count > top_n) break;
880+
results.push_back(elem);
891881
}
882+
if (is_tei_format) return results;
883+
884+
json res = json{
885+
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
886+
{"object", "list"},
887+
{"usage", json{
888+
{"prompt_tokens", n_tokens},
889+
{"total_tokens", n_tokens}
890+
}},
891+
{"results", results}
892+
};
892893

893894
return res;
894895
}

0 commit comments

Comments
 (0)