Skip to content

Commit 88cc971

Browse files
committed
server : fill usage info in reranking response
1 parent 357a7ba commit 88cc971

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

examples/server/server.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
738738
int index = 0;
739739
float score = -1e6;
740740

741+
int32_t n_tokens;
742+
741743
virtual int get_index() override {
742744
return index;
743745
}
744746

745747
virtual json to_json() override {
746748
return json {
747-
{"index", index},
748-
{"score", score},
749+
{"index", index},
750+
{"score", score},
751+
{"tokens_evaluated", n_tokens},
749752
};
750753
}
751754
};
@@ -2034,6 +2037,7 @@ struct server_context {
20342037
auto res = std::make_unique<server_task_result_rerank>();
20352038
res->id = slot.id_task;
20362039
res->index = slot.index;
2040+
res->n_tokens = slot.n_prompt_tokens;
20372041

20382042
for (int i = 0; i < batch.n_tokens; ++i) {
20392043
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {

examples/server/tests/unit/test_rerank.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
5353
})
5454
assert res.status_code == 400
5555
assert "error" in res.body
56+
57+
58+
@pytest.mark.parametrize(
59+
"query,doc1,doc2,n_tokens",
60+
[
61+
("Machine learning is", "A machine", "Learning is", 19),
62+
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
63+
]
64+
)
65+
def test_rerank_usage(query, doc1, doc2, n_tokens):
66+
global server
67+
server.start()
68+
69+
res = server.make_request("POST", "/rerank", data={
70+
"query": query,
71+
"documents": [
72+
doc1,
73+
doc2,
74+
]
75+
})
76+
assert res.status_code == 200
77+
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
78+
assert res.body['usage']['prompt_tokens'] == n_tokens

examples/server/utils.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -587,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
587587

588588
static json format_response_rerank(const json & request, const json & ranks) {
589589
json data = json::array();
590+
int32_t n_tokens = 0;
590591
int i = 0;
591592
for (const auto & rank : ranks) {
592593
data.push_back(json{
593594
{"index", i++},
594595
{"relevance_score", json_value(rank, "score", 0.0)},
595596
});
597+
598+
n_tokens += json_value(rank, "tokens_evaluated", 0);
596599
}
597600

598601
json res = json {
599602
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
600603
{"object", "list"},
601-
{"usage", json { // TODO: fill
602-
{"prompt_tokens", 0},
603-
{"total_tokens", 0}
604+
{"usage", json {
605+
{"prompt_tokens", n_tokens},
606+
{"total_tokens", n_tokens}
604607
}},
605608
{"results", data}
606609
};

0 commit comments

Comments
 (0)