Skip to content

Commit 884377f

Browse files
RobbieLeungyq33victor
authored and committed
feat: hard-coded version of qwen3 official rerank interface.
1 parent b3a0b61 commit 884377f

File tree

12 files changed

+267
-91
lines changed

12 files changed

+267
-91
lines changed

xllm/api_service/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cc_library(
1212
embedding_service_impl.h
1313
image_generation_service_impl.h
1414
rerank_service_impl.h
15+
qwen3_rerank_service_impl.h
1516
non_stream_call.h
1617
service_impl_factory.h
1718
stream_call.h
@@ -25,6 +26,7 @@ cc_library(
2526
image_generation_service_impl.cpp
2627
models_service_impl.cpp
2728
rerank_service_impl.cpp
29+
qwen3_rerank_service_impl.cpp
2830
DEPS
2931
:master
3032
:chat_template

xllm/api_service/api_service.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,15 @@ APIService::APIService(Master* master,
5151
embedding_service_impl_ =
5252
ServiceImplFactory<EmbeddingServiceImpl>::create_service_impl(
5353
llm_master, model_names);
54-
rerank_service_impl_ =
55-
ServiceImplFactory<RerankServiceImpl>::create_service_impl(llm_master,
56-
model_names);
54+
if (FLAGS_enable_qwen3_reranker) {
55+
rerank_service_impl_ =
56+
ServiceImplFactory<Qwen3RerankServiceImpl>::create_service_impl(
57+
llm_master, model_names);
58+
} else {
59+
rerank_service_impl_ =
60+
ServiceImplFactory<RerankServiceImpl>::create_service_impl(
61+
llm_master, model_names);
62+
}
5763
} else if (FLAGS_backend == "vlm") {
5864
auto vlm_master = dynamic_cast<VLMMaster*>(master);
5965
mm_chat_service_impl_ =

xllm/api_service/api_service.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include "embedding_service_impl.h"
2121
#include "image_generation_service_impl.h"
2222
#include "models_service_impl.h"
23+
#include "qwen3_rerank_service_impl.h"
2324
#include "rerank_service_impl.h"
2425
#include "xllm_service.pb.h"
2526

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "api_service/qwen3_rerank_service_impl.h"
17+
18+
#include "util/blocking_counter.h"
19+
20+
namespace xllm {
21+
22+
// Constructs the Qwen3 rerank service. Both arguments are forwarded unchanged
// to the RerankServiceImpl base constructor; this class adds no state of its
// own.
Qwen3RerankServiceImpl::Qwen3RerankServiceImpl(
    LLMMaster* master,
    const std::vector<std::string>& models)
    : RerankServiceImpl(master, models) {
}
26+
27+
void Qwen3RerankServiceImpl::process_async_impl(
28+
std::shared_ptr<RerankCall> call) {
29+
const auto& rpc_request = call->request();
30+
// check if model is supported
31+
const auto& model = rpc_request.model();
32+
if (!models_.contains(model)) {
33+
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
34+
return;
35+
}
36+
37+
auto query = rpc_request.query();
38+
std::vector<std::string> documents;
39+
if (rpc_request.documents_size() > 0) {
40+
documents = std::vector<std::string>(rpc_request.documents().begin(),
41+
rpc_request.documents().end());
42+
}
43+
std::vector<std::string> reqs;
44+
reqs.reserve(documents.size());
45+
for (size_t i = 0; i < documents.size(); ++i) {
46+
reqs.emplace_back(query + documents[i]);
47+
}
48+
49+
// create RequestParams for rerank request
50+
RequestParams request_params(
51+
rpc_request, call->get_x_request_id(), call->get_x_request_time());
52+
std::vector<RequestParams> sps(documents.size(), request_params);
53+
auto request_id = request_params.request_id;
54+
auto created_time = absl::ToUnixSeconds(absl::Now());
55+
56+
// schedule the request
57+
std::vector<RequestOutput> req_outputs;
58+
req_outputs.resize(documents.size());
59+
BlockingCounter counter(documents.size());
60+
61+
auto batch_callback = [&req_outputs, &counter](size_t index,
62+
RequestOutput output) -> bool {
63+
req_outputs[index] = std::move(output);
64+
counter.decrement_count();
65+
return true;
66+
};
67+
68+
master_->handle_batch_request(reqs, sps, batch_callback);
69+
70+
// Wait for all tasks to complete
71+
counter.wait();
72+
73+
// get score
74+
size_t doc_size = documents.size();
75+
std::vector<RerankRequestOutput> rerank_outputs;
76+
rerank_outputs.reserve(doc_size);
77+
for (size_t i = 0; i < doc_size; ++i) {
78+
if (req_outputs[i].outputs[0].logprobs.has_value()) {
79+
auto score = req_outputs[i].outputs[0].logprobs.value()[0].logprob;
80+
rerank_outputs.emplace_back(i, documents[i], score);
81+
}
82+
}
83+
84+
// send result to client
85+
int32_t top_n = documents.size();
86+
if (rpc_request.has_top_n()) {
87+
top_n = std::min(top_n, rpc_request.top_n());
88+
}
89+
send_result_to_client_brpc(call,
90+
request_id,
91+
created_time,
92+
model,
93+
top_n,
94+
rerank_outputs,
95+
req_outputs);
96+
}
97+
98+
} // namespace xllm
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include "api_service/rerank_service_impl.h"
19+
20+
namespace xllm {
21+
using RerankCall = NonStreamCall<proto::RerankRequest, proto::RerankResponse>;
22+
23+
// a class to handle completion requests
24+
class Qwen3RerankServiceImpl final : public RerankServiceImpl {
25+
public:
26+
Qwen3RerankServiceImpl(LLMMaster* master,
27+
const std::vector<std::string>& models);
28+
29+
// brpc call_data needs to use shared_ptr
30+
void process_async_impl(std::shared_ptr<RerankCall> call) override;
31+
32+
private:
33+
DISALLOW_COPY_AND_ASSIGN(Qwen3RerankServiceImpl);
34+
};
35+
} // namespace xllm

xllm/api_service/rerank_service_impl.cpp

Lines changed: 75 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ limitations under the License.
1515

1616
#include "rerank_service_impl.h"
1717

18-
#include <glog/logging.h>
1918
#include <torch/torch.h>
2019

2120
#include <string>
@@ -28,29 +27,54 @@ limitations under the License.
2827
#include "util/uuid.h"
2928

3029
namespace xllm {
31-
namespace {
32-
33-
struct RerankRequestOutput {
34-
int32_t index = 0;
35-
std::string document = "";
36-
float score = 0.0f;
37-
38-
RerankRequestOutput(int32_t index, std::string document, float score)
39-
: index(index), document(std::move(document)), score(score) {}
40-
};
41-
42-
bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
43-
const std::string& request_id,
44-
int64_t created_time,
45-
const std::string& model,
46-
const std::vector<std::string>& documents,
47-
int32_t top_n,
48-
const std::vector<RequestOutput>& req_outputs) {
49-
auto& response = call->response();
50-
response.set_id(request_id);
51-
response.set_model(model);
30+
RerankServiceImpl::RerankServiceImpl(LLMMaster* master,
31+
const std::vector<std::string>& models)
32+
: APIServiceImpl(models), master_(master) {
33+
CHECK(master_ != nullptr);
34+
}
35+
36+
// rerank_async for brpc
37+
void RerankServiceImpl::process_async_impl(std::shared_ptr<RerankCall> call) {
38+
const auto& rpc_request = call->request();
39+
// check if model is supported
40+
const auto& model = rpc_request.model();
41+
if (!models_.contains(model)) {
42+
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
43+
return;
44+
}
45+
46+
std::vector<std::string> documents;
47+
if (rpc_request.documents_size() > 0) {
48+
documents = std::vector<std::string>(rpc_request.documents().begin(),
49+
rpc_request.documents().end());
50+
}
51+
documents.emplace_back(rpc_request.query());
52+
53+
// create RequestParams for rerank request
54+
RequestParams request_params(
55+
rpc_request, call->get_x_request_id(), call->get_x_request_time());
56+
std::vector<RequestParams> sps(documents.size(), request_params);
57+
auto request_id = request_params.request_id;
58+
auto created_time = absl::ToUnixSeconds(absl::Now());
59+
60+
// schedule the request
61+
std::vector<RequestOutput> req_outputs;
62+
req_outputs.resize(documents.size());
63+
BlockingCounter counter(documents.size());
64+
65+
auto batch_callback = [&req_outputs, &counter](size_t index,
66+
RequestOutput output) -> bool {
67+
req_outputs[index] = std::move(output);
68+
counter.decrement_count();
69+
return true;
70+
};
71+
72+
master_->handle_batch_request(documents, sps, batch_callback);
73+
74+
// Wait for all tasks to complete
75+
counter.wait();
5276

53-
// calculate cosine similarity
77+
// calculate cosine similarity to get score
5478
size_t doc_size = documents.size() - 1;
5579
std::string query = documents[doc_size];
5680
auto query_embed = req_outputs[doc_size].outputs[0].embeddings.value();
@@ -70,16 +94,41 @@ bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
7094
}
7195
}
7296

97+
// send result to client
98+
int32_t top_n = documents.size() - 1;
99+
if (rpc_request.has_top_n()) {
100+
top_n = std::min(top_n, rpc_request.top_n());
101+
}
102+
send_result_to_client_brpc(call,
103+
request_id,
104+
created_time,
105+
model,
106+
top_n,
107+
rerank_outputs,
108+
req_outputs);
109+
}
110+
111+
bool RerankServiceImpl::send_result_to_client_brpc(
112+
std::shared_ptr<RerankCall> call,
113+
const std::string& request_id,
114+
int64_t created_time,
115+
const std::string& model,
116+
int32_t top_n,
117+
std::vector<RerankRequestOutput>& rerank_outputs,
118+
const std::vector<RequestOutput>& req_outputs) {
119+
auto& response = call->response();
120+
response.set_id(request_id);
121+
response.set_model(model);
122+
73123
std::sort(rerank_outputs.begin(),
74124
rerank_outputs.end(),
75125
[](const RerankRequestOutput& a, const RerankRequestOutput& b) {
76126
return a.score > b.score;
77127
});
78128

79129
// add top_n results
80-
int32_t valid_top_n = std::min(top_n, static_cast<int32_t>(doc_size));
81-
response.mutable_results()->Reserve(valid_top_n);
82-
for (int32_t i = 0; i < valid_top_n; ++i) {
130+
response.mutable_results()->Reserve(top_n);
131+
for (int32_t i = 0; i < top_n; ++i) {
83132
auto* result = response.add_results();
84133
result->set_index(rerank_outputs[i].index);
85134
auto* document = result->mutable_document();
@@ -109,62 +158,4 @@ bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
109158
return call->write_and_finish(response);
110159
}
111160

112-
} // namespace
113-
114-
RerankServiceImpl::RerankServiceImpl(LLMMaster* master,
115-
const std::vector<std::string>& models)
116-
: APIServiceImpl(models), master_(master) {
117-
CHECK(master_ != nullptr);
118-
}
119-
120-
// rerank_async for brpc
121-
void RerankServiceImpl::process_async_impl(std::shared_ptr<RerankCall> call) {
122-
const auto& rpc_request = call->request();
123-
// check if model is supported
124-
const auto& model = rpc_request.model();
125-
if (!models_.contains(model)) {
126-
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
127-
return;
128-
}
129-
130-
std::vector<std::string> documents;
131-
if (rpc_request.documents_size() > 0) {
132-
documents = std::vector<std::string>(rpc_request.documents().begin(),
133-
rpc_request.documents().end());
134-
}
135-
documents.emplace_back(rpc_request.query());
136-
137-
// create RequestParams for rerank request
138-
RequestParams request_params(
139-
rpc_request, call->get_x_request_id(), call->get_x_request_time());
140-
std::vector<RequestParams> sps(documents.size(), request_params);
141-
auto request_id = request_params.request_id;
142-
auto created_time = absl::ToUnixSeconds(absl::Now());
143-
144-
// schedule the request
145-
std::vector<RequestOutput> req_outputs;
146-
req_outputs.resize(documents.size());
147-
BlockingCounter counter(documents.size());
148-
149-
auto batch_callback = [&req_outputs, &counter](size_t index,
150-
RequestOutput output) -> bool {
151-
req_outputs[index] = std::move(output);
152-
counter.decrement_count();
153-
return true;
154-
};
155-
156-
master_->handle_batch_request(documents, sps, batch_callback);
157-
158-
// Wait for all tasks to complete
159-
counter.wait();
160-
161-
int32_t top_n = documents.size() - 1;
162-
if (rpc_request.has_top_n()) {
163-
top_n = rpc_request.top_n();
164-
}
165-
166-
send_result_to_client_brpc(
167-
call, request_id, created_time, model, documents, top_n, req_outputs);
168-
}
169-
170161
} // namespace xllm

0 commit comments

Comments
 (0)