@@ -849,46 +849,47 @@ static json format_response_rerank(
849849 const json & request,
850850 const json & ranks,
851851 bool is_tei_format,
852- std::vector<std::string> & texts) {
853- json res;
854- if (is_tei_format) {
855- // TEI response format
856- res = json::array ();
857- bool return_text = json_value (request, " return_text" , false );
858- for (const auto & rank : ranks) {
859- int index = json_value (rank, " index" , 0 );
860- json elem = json{
861- {" index" , index},
862- {" score" , json_value (rank, " score" , 0.0 )},
863- };
864- if (return_text) {
865- elem[" text" ] = std::move (texts[index]);
866- }
867- res.push_back (elem);
868- }
869- } else {
870- // Jina response format
871- json results = json::array ();
872- int32_t n_tokens = 0 ;
873- for (const auto & rank : ranks) {
874- results.push_back (json{
875- {" index" , json_value (rank, " index" , 0 )},
876- {" relevance_score" , json_value (rank, " score" , 0.0 )},
877- });
878-
879- n_tokens += json_value (rank, " tokens_evaluated" , 0 );
852+ std::vector<std::string> & texts,
853+ int top_n) {
854+ json results;
855+ int32_t n_tokens = 0 ;
856+ bool return_text = is_tei_format && json_value (request, " return_text" , false );
857+ std::vector<json> elements; // Temporary vector to hold unsorted elements
858+ std::string score_label = is_tei_format ? " score" : " relevance_score" ;
859+ for (const auto & rank : ranks) {
860+ int index = json_value (rank, " index" , 0 );
861+ json elem = json{
862+ {" index" , index},
863+ {score_label, json_value (rank, " score" , 0.0 )},
864+ };
865+ n_tokens += json_value (rank, " tokens_evaluated" , 0 );
866+ if (return_text) {
867+ elem[" text" ] = std::move (texts[index]);
880868 }
869+ elements.push_back (elem);
870+ }
871+
872+ std::sort (elements.begin (), elements.end (), [score_label](const json& a, const json& b) {
873+ return json_value (a, score_label, 0.0 ) > json_value (b, score_label, 0.0 );
874+ });
881875
882- res = json{
883- {" model" , json_value (request, " model" , std::string (DEFAULT_OAICOMPAT_MODEL))},
884- {" object" , " list" },
885- {" usage" , json{
886- {" prompt_tokens" , n_tokens},
887- {" total_tokens" , n_tokens}
888- }},
889- {" results" , results}
890- };
876+ results = json::array ();
877+ int count = 0 ;
878+ for (const auto & elem : elements) {
879+ if (++count > top_n) break ;
880+ results.push_back (elem);
891881 }
882+ if (is_tei_format) return results;
883+
884+ json res = json{
885+ {" model" , json_value (request, " model" , std::string (DEFAULT_OAICOMPAT_MODEL))},
886+ {" object" , " list" },
887+ {" usage" , json{
888+ {" prompt_tokens" , n_tokens},
889+ {" total_tokens" , n_tokens}
890+ }},
891+ {" results" , results}
892+ };
892893
893894 return res;
894895}
0 commit comments