@@ -15,7 +15,6 @@ limitations under the License.
1515
1616#include " rerank_service_impl.h"
1717
18- #include < glog/logging.h>
1918#include < torch/torch.h>
2019
2120#include < string>
@@ -28,29 +27,54 @@ limitations under the License.
2827#include " util/uuid.h"
2928
3029namespace xllm {
31- namespace {
32-
33- struct RerankRequestOutput {
34- int32_t index = 0 ;
35- std::string document = " " ;
36- float score = 0 .0f ;
37-
38- RerankRequestOutput (int32_t index, std::string document, float score)
39- : index(index), document(std::move(document)), score(score) {}
40- };
41-
42- bool send_result_to_client_brpc (std::shared_ptr<RerankCall> call,
43- const std::string& request_id,
44- int64_t created_time,
45- const std::string& model,
46- const std::vector<std::string>& documents,
47- int32_t top_n,
48- const std::vector<RequestOutput>& req_outputs) {
49- auto & response = call->response ();
50- response.set_id (request_id);
51- response.set_model (model);
// Constructs the rerank service: `models` is forwarded to the APIServiceImpl
// base and the LLMMaster pointer is stored in master_ (no ownership transfer
// is visible here).
// NOTE(review): CHECK is presumably the glog-style fatal assertion. This same
// change removes the direct `#include <glog/logging.h>`, so the macro must now
// arrive through another header - confirm the file still includes what it uses.
30+ RerankServiceImpl::RerankServiceImpl (LLMMaster* master,
31+ const std::vector<std::string>& models)
32+ : APIServiceImpl(models), master_(master) {
33+ CHECK (master_ != nullptr );
34+ }
35+
36+ // rerank_async for brpc
// Handles one rerank RPC end to end:
//   1. rejects the call if the requested model is not in models_;
//   2. builds a batch holding every request document plus the query, with the
//      query appended as the LAST entry;
//   3. submits the batch to master_ and blocks until every entry has reported
//      its RequestOutput;
//   4. derives the scores and passes the results to
//      send_result_to_client_brpc().
// Despite the "async" in the name, the visible body blocks the calling thread
// on counter.wait().
37+ void RerankServiceImpl::process_async_impl (std::shared_ptr<RerankCall> call) {
38+ const auto & rpc_request = call->request ();
39+ // check if model is supported
40+ const auto & model = rpc_request.model ();
41+ if (!models_.contains (model)) {
42+ call->finish_with_error (StatusCode::UNKNOWN, " Model not supported" );
43+ return ;
44+ }
45+
46+ std::vector<std::string> documents;
47+ if (rpc_request.documents_size () > 0 ) {
48+ documents = std::vector<std::string>(rpc_request.documents ().begin (),
49+ rpc_request.documents ().end ());
50+ }
// The query rides in the same batch as the documents; its output is read back
// from index documents.size() - 1 further down.
51+ documents.emplace_back (rpc_request.query ());
52+
53+ // create RequestParams for rerank request
54+ RequestParams request_params (
55+ rpc_request, call->get_x_request_id (), call->get_x_request_time ());
// Every batch entry gets its own copy of the same RequestParams.
56+ std::vector<RequestParams> sps (documents.size (), request_params);
57+ auto request_id = request_params.request_id ;
58+ auto created_time = absl::ToUnixSeconds (absl::Now ());
59+
60+ // schedule the request
61+ std::vector<RequestOutput> req_outputs;
62+ req_outputs.resize (documents.size ());
63+ BlockingCounter counter (documents.size ());
64+
// The callback stores each output in the slot named by `index` and counts the
// entry as done. It captures req_outputs and counter by reference, which is
// only safe because counter.wait() below keeps this stack frame alive.
// NOTE(review): if handle_batch_request can drop an entry without ever
// invoking the callback, wait() presumably never returns and the call is
// never finished - confirm the callback is guaranteed to fire once per entry.
65+ auto batch_callback = [&req_outputs, &counter](size_t index,
66+ RequestOutput output) -> bool {
67+ req_outputs[index] = std::move (output);
68+ counter.decrement_count ();
69+ return true ;
70+ };
71+
72+ master_->handle_batch_request (documents, sps, batch_callback);
73+
74+ // Wait for all tasks to complete
75+ counter.wait ();
5276
53- // calculate cosine similarity
77+ // calculate cosine similarity to get score
// doc_size is the number of real documents; index doc_size is the query that
// was appended last.
// NOTE(review): outputs[0] and embeddings.value() are read without checking
// that the output list is non-empty or that embeddings is set - confirm that
// a failed request cannot reach this point.
5478 size_t doc_size = documents.size () - 1 ;
5579 std::string query = documents[doc_size];
5680 auto query_embed = req_outputs[doc_size].outputs [0 ].embeddings .value ();
// NOTE(review): the lines between this point and the closing braces below are
// not part of this excerpt; rerank_outputs is presumably declared and filled
// with per-document scores there.
@@ -70,16 +94,41 @@ bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
7094 }
7195 }
7296
97+ // send result to client
// top_n defaults to the number of documents (query excluded) and is capped at
// that number when the request carries its own top_n.
// NOTE(review): a negative request top_n is not clamped to zero; with the
// `i < top_n` loop in send_result_to_client_brpc it yields an empty result
// list - confirm whether such a request should be rejected instead.
98+ int32_t top_n = documents.size () - 1 ;
99+ if (rpc_request.has_top_n ()) {
100+ top_n = std::min (top_n, rpc_request.top_n ());
101+ }
// NOTE(review): the bool returned by send_result_to_client_brpc is ignored.
102+ send_result_to_client_brpc (call,
103+ request_id,
104+ created_time,
105+ model,
106+ top_n,
107+ rerank_outputs,
108+ req_outputs);
109+ }
110+
// Fills call->response() with the request id, the model name and the first
// top_n entries of rerank_outputs after sorting them by descending score,
// then completes the RPC and returns the result of call->write_and_finish().
//
// rerank_outputs is taken by non-const reference and is reordered in place.
// std::sort is not stable, so entries with equal scores come out in an
// unspecified relative order.
// NOTE(review): RerankRequestOutput is no longer defined in this file after
// this change; it is presumably declared in the class header now - confirm.
111+ bool RerankServiceImpl::send_result_to_client_brpc (
112+ std::shared_ptr<RerankCall> call,
113+ const std::string& request_id,
114+ int64_t created_time,
115+ const std::string& model,
116+ int32_t top_n,
117+ std::vector<RerankRequestOutput>& rerank_outputs,
118+ const std::vector<RequestOutput>& req_outputs) {
119+ auto & response = call->response ();
120+ response.set_id (request_id);
121+ response.set_model (model);
122+
// Highest score first.
73123 std::sort (rerank_outputs.begin (),
74124 rerank_outputs.end (),
75125 [](const RerankRequestOutput& a, const RerankRequestOutput& b) {
76126 return a.score > b.score ;
77127 });
78128
79129 // add top_n results
// NOTE(review): the previous version clamped top_n against the document count
// right here; that clamp now lives only in the caller, and rerank_outputs[i]
// below is indexed without a check against rerank_outputs.size(). Confirm
// every caller passes top_n <= rerank_outputs.size().
80- int32_t valid_top_n = std::min (top_n, static_cast <int32_t >(doc_size));
81- response.mutable_results ()->Reserve (valid_top_n);
82- for (int32_t i = 0 ; i < valid_top_n; ++i) {
130+ response.mutable_results ()->Reserve (top_n);
131+ for (int32_t i = 0 ; i < top_n; ++i) {
83132 auto * result = response.add_results ();
84133 result->set_index (rerank_outputs[i].index );
85134 auto * document = result->mutable_document ();
// NOTE(review): the rest of the loop body and whatever follows it are not
// part of this excerpt.
@@ -109,62 +158,4 @@ bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
109158 return call->write_and_finish (response);
110159}
111160
112- } // namespace
113-
114- RerankServiceImpl::RerankServiceImpl (LLMMaster* master,
115- const std::vector<std::string>& models)
116- : APIServiceImpl(models), master_(master) {
117- CHECK (master_ != nullptr );
118- }
119-
120- // rerank_async for brpc
121- void RerankServiceImpl::process_async_impl (std::shared_ptr<RerankCall> call) {
122- const auto & rpc_request = call->request ();
123- // check if model is supported
124- const auto & model = rpc_request.model ();
125- if (!models_.contains (model)) {
126- call->finish_with_error (StatusCode::UNKNOWN, " Model not supported" );
127- return ;
128- }
129-
130- std::vector<std::string> documents;
131- if (rpc_request.documents_size () > 0 ) {
132- documents = std::vector<std::string>(rpc_request.documents ().begin (),
133- rpc_request.documents ().end ());
134- }
135- documents.emplace_back (rpc_request.query ());
136-
137- // create RequestParams for rerank request
138- RequestParams request_params (
139- rpc_request, call->get_x_request_id (), call->get_x_request_time ());
140- std::vector<RequestParams> sps (documents.size (), request_params);
141- auto request_id = request_params.request_id ;
142- auto created_time = absl::ToUnixSeconds (absl::Now ());
143-
144- // schedule the request
145- std::vector<RequestOutput> req_outputs;
146- req_outputs.resize (documents.size ());
147- BlockingCounter counter (documents.size ());
148-
149- auto batch_callback = [&req_outputs, &counter](size_t index,
150- RequestOutput output) -> bool {
151- req_outputs[index] = std::move (output);
152- counter.decrement_count ();
153- return true ;
154- };
155-
156- master_->handle_batch_request (documents, sps, batch_callback);
157-
158- // Wait for all tasks to complete
159- counter.wait ();
160-
161- int32_t top_n = documents.size () - 1 ;
162- if (rpc_request.has_top_n ()) {
163- top_n = rpc_request.top_n ();
164- }
165-
166- send_result_to_client_brpc (
167- call, request_id, created_time, model, documents, top_n, req_outputs);
168- }
169-
170161} // namespace xllm
0 commit comments