Skip to content

Commit 2e65fae

Browse files
Workaround support for qwen3 rerank (#3578)
1 parent 285bb13 commit 2e65fae

File tree

7 files changed

+143
-11
lines changed

7 files changed

+143
-11
lines changed

demos/rerank/README.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,57 @@ index 1, relevance_score 0.09138210117816925
144144
```
145145
:::
146146

147+
:::{dropdown} **Requesting rerank score with model that requires template applying on query and documents**
148+
149+
tomaarsen/Qwen3-Reranker-0.6B-seq-cls is a copy of the Qwen3-Reranker-0.6B model (original model is not supported in OVMS) modified as a sequence classification model instead. It requires applying template on input, here is example client that does it:
150+
151+
```bash
152+
pip3 install requests
153+
```
154+
```bash
155+
echo '
156+
import requests
157+
158+
prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
159+
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
160+
161+
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
162+
document_template = "<Document>: {doc}{suffix}"
163+
164+
instruction = (
165+
"Given a web search query, retrieve relevant passages that answer the query"
166+
)
167+
168+
query = "welcome"
169+
170+
documents = [
171+
"good morning",
172+
"farewell",
173+
]
174+
175+
query = query_template.format(prefix=prefix, instruction=instruction, query=query)
176+
177+
documents = [
178+
document_template.format(doc=doc, suffix=suffix) for doc in documents
179+
]
180+
181+
response = requests.post("http://127.0.0.1:8125/v3/rerank",
182+
json={
183+
"model": "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
184+
"query": query,
185+
"documents": documents,
186+
}).json()
187+
188+
print(response)' > rerank_client.py
189+
190+
python rerank_client.py
191+
```
192+
It will return response similar to:
193+
```
194+
{'results': [{'index': 0, 'relevance_score': 0.024518223479390144}, {'index': 1, 'relevance_score': 0.0026006349362432957}]}
195+
```
196+
:::
197+
147198
## Comparison with Hugging Faces
148199

149200
```bash
@@ -202,6 +253,7 @@ BAAI/bge-reranker-large
202253
BAAI/bge-reranker-v2-m3
203254
BAAI/bge-reranker-base
204255
cross-encoder/msmarco-MiniLM-L6-en-de-v1
256+
tomaarsen/Qwen3-Reranker-0.6B-seq-cls
205257
```
206258

207259
## Integration with Langchain

src/mediapipe_internal/mediapipegraphdefinition.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,8 @@ Status MediapipeGraphDefinition::initializeNodes() {
550550
}
551551
mediapipe::RerankCalculatorOVOptions nodeOptions;
552552
config.node(i).node_options(0).UnpackTo(&nodeOptions);
553-
std::shared_ptr<SidepacketServable> servable = std::make_shared<SidepacketServable>(nodeOptions.models_path(), nodeOptions.target_device(), nodeOptions.plugin_config(), mgconfig.getBasePath());
554-
rerankServableMap.insert(std::pair<std::string, std::shared_ptr<SidepacketServable>>(nodeName, std::move(servable)));
553+
std::shared_ptr<RerankServable> servable = std::make_shared<RerankServable>(nodeOptions.models_path(), nodeOptions.target_device(), nodeOptions.plugin_config(), mgconfig.getBasePath());
554+
rerankServableMap.insert(std::pair<std::string, std::shared_ptr<RerankServable>>(nodeName, std::move(servable)));
555555
rerankServablesCleaningGuard.disableCleaning();
556556
}
557557
}

src/mediapipe_internal/mediapipegraphdefinition.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
#include "../sidepacket_servable.hpp"
4747
#include "../embeddings/embeddings_servable.hpp"
48+
#include "../rerank/rerank_servable.hpp"
4849

4950
namespace ovms {
5051
class MediapipeGraphDefinitionUnloadGuard;
@@ -60,8 +61,8 @@ class GenAiServable;
6061
struct ImageGenerationPipelines;
6162
using PythonNodeResourcesMap = std::unordered_map<std::string, std::shared_ptr<PythonNodeResources>>;
6263
using GenAiServableMap = std::unordered_map<std::string, std::shared_ptr<GenAiServable>>;
64+
using RerankServableMap = std::unordered_map<std::string, std::shared_ptr<RerankServable>>;
6365
using EmbeddingsServableMap = std::unordered_map<std::string, std::shared_ptr<EmbeddingsServable>>;
64-
using RerankServableMap = std::unordered_map<std::string, std::shared_ptr<SidepacketServable>>;
6566
using ImageGenerationPipelinesMap = std::unordered_map<std::string, std::shared_ptr<ImageGenerationPipelines>>;
6667

6768
struct GraphSidePackets {

src/rerank/BUILD

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ mediapipe_proto_library(
2727
],
2828
)
2929

30+
ovms_cc_library(
31+
name = "rerank_servable",
32+
hdrs = ["rerank_servable.hpp"],
33+
deps = ["//src:sidepacket_servable",],
34+
visibility = ["//visibility:public"],
35+
alwayslink = 1,
36+
)
37+
3038
mediapipe_proto_library(
3139
name = "rerank_calculator_ov_proto", # rerank_calculator_cc_proto - just mediapipe stuff with mediapipe_proto_library adding nonvisible target
3240
srcs = ["rerank_calculator_ov.proto"],
@@ -68,7 +76,7 @@ ovms_cc_library(
6876
"//src:libovmsprofiler",
6977
"rerank_calculator_ov_cc_proto",
7078
":rerank_api_handler",
71-
"//src:sidepacket_servable",
79+
":rerank_servable",
7280
"//src:model_metric_reporter",
7381
"//src:executingstreamidguard",
7482
"//src:libovms_execution_context",

src/rerank/rerank_calculator_ov.cc

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
#include "../profiler.hpp"
4343
#include "src/rerank/rerank_calculator_ov.pb.h"
4444
#include "src/rerank/rerank_utils.hpp"
45-
#include "../sidepacket_servable.hpp"
45+
#include "rerank_servable.hpp"
4646
#include "../model_metric_reporter.hpp"
4747
#include "../executingstreamidguard.hpp"
4848

@@ -77,7 +77,7 @@ class RerankCalculatorOV : public CalculatorBase {
7777
size_t max_allowed_chunks{0}; // Read from options in ::Open()
7878

7979
protected:
80-
std::shared_ptr<ovms::SidepacketServable> rerank_session{nullptr};
80+
std::shared_ptr<ovms::RerankServable> rerank_session{nullptr};
8181

8282
public:
8383
static absl::Status GetContract(CalculatorContract* cc) {
@@ -127,7 +127,7 @@ class RerankCalculatorOV : public CalculatorBase {
127127
}
128128

129129
// post-validation
130-
if (this->max_position_embeddings <= 2 * NUMBER_OF_SPECIAL_TOKENS) {
130+
if (rerank_session->addBosToken && (this->max_position_embeddings <= 2 * NUMBER_OF_SPECIAL_TOKENS)) {
131131
SPDLOG_LOGGER_ERROR(rerank_calculator_logger, "max_position_embeddings should be larger than 2 * NUMBER_OF_SPECIAL_TOKENS");
132132
return absl::InvalidArgumentError("max_position_embeddings should be larger than 2 * NUMBER_OF_SPECIAL_TOKENS");
133133
}
@@ -153,7 +153,25 @@ class RerankCalculatorOV : public CalculatorBase {
153153
// Validate batch size before tokenizing
154154
if (handler.getDocumentsList().size() > this->max_allowed_chunks)
155155
throw std::runtime_error("Number of documents exceeds max_allowed_chunks");
156-
156+
if (!rerank_session->addBosToken) {
157+
auto batchSize = handler.getDocumentsList().size();
158+
std::vector<std::string> data(batchSize);
159+
for (int i = 0; i < batchSize; i++) {
160+
data[i] += handler.getQuery() + handler.getDocumentsList()[i];
161+
}
162+
chunk_mapping.resize(batchSize);
163+
std::iota(chunk_mapping.begin(), chunk_mapping.end(), 0);
164+
auto tokens = rerank_session->getTokenizer().encode(data);
165+
if (tokens.input_ids.get_shape().size() != 2) {
166+
throw std::runtime_error("Tokens shape invalid."); // should never happen
167+
}
168+
if (this->max_position_embeddings < tokens.input_ids.get_shape()[1]) {
169+
std::ostringstream msg;
170+
msg << "The requests length of " << tokens.input_ids.get_shape()[1] << " tokens exceeds the model context of " << max_position_embeddings;
171+
throw std::runtime_error(msg.str());
172+
}
173+
return std::make_pair(tokens.input_ids, tokens.attention_mask);
174+
}
157175
// Compute Query Tokens
158176
auto query_tokens = ComputeTokensForString(handler.getQuery());
159177

@@ -289,8 +307,8 @@ class RerankCalculatorOV : public CalculatorBase {
289307
typeIds = ov::Tensor{ov::element::i64, input_ids.get_shape()};
290308
std::fill_n(typeIds->data<int64_t>(), input_ids.get_size(), 0);
291309
}
292-
// Compute scores using rerank model
293310
size_t batch_size = handler.getDocumentsList().size();
311+
// Compute scores using rerank model
294312
auto scores = ComputeScoresUsingRerankModel(
295313
input_ids,
296314
attention_mask,

src/rerank/rerank_servable.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//*****************************************************************************
2+
// Copyright 2025 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
#pragma once
17+
18+
#include "../sidepacket_servable.hpp"
19+
#include "../filesystem.hpp"
20+
#include <rapidjson/istreamwrapper.h>
21+
#include <rapidjson/error/en.h>
22+
#include <memory>
23+
#include <string>
24+
#include <unordered_map>
25+
26+
namespace ovms {
27+
28+
struct RerankServable : SidepacketServable {
29+
bool addBosToken = true;
30+
RerankServable(const std::string& modelDir, const std::string& targetDevice, const std::string& pluginConfig, const std::string& graphPath) :
31+
SidepacketServable(modelDir, targetDevice, pluginConfig, graphPath) {
32+
std::filesystem::path tokenizerConfigPath = (parsedModelsPath / "tokenizer_config.json");
33+
if (!std::filesystem::exists(tokenizerConfigPath)) {
34+
return;
35+
}
36+
std::ifstream ifs(tokenizerConfigPath.string());
37+
if (!ifs.is_open()) {
38+
return;
39+
}
40+
rapidjson::Document tokenizerConfig;
41+
rapidjson::IStreamWrapper isw(ifs);
42+
rapidjson::ParseResult parseResult = tokenizerConfig.ParseStream(isw);
43+
if (parseResult.Code()) {
44+
SPDLOG_ERROR("Parsing tokenizer_config.json failed: {}", rapidjson::GetParseError_En(parseResult.Code()));
45+
return;
46+
}
47+
if (tokenizerConfig.HasMember("add_bos_token") && tokenizerConfig["add_bos_token"].IsBool() && tokenizerConfig["add_bos_token"].IsFalse()) {
48+
SPDLOG_DEBUG("Rerank model add_bos_token set to false");
49+
addBosToken = false;
50+
}
51+
}
52+
};
53+
54+
using RerankServableMap = std::unordered_map<std::string, std::shared_ptr<RerankServable>>;
55+
} // namespace ovms

src/sidepacket_servable.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,4 @@ struct SidepacketServable {
7676
return compiledModel.inputs().size();
7777
}
7878
};
79-
80-
using RerankServableMap = std::unordered_map<std::string, std::shared_ptr<SidepacketServable>>;
8179
} // namespace ovms

0 commit comments

Comments
 (0)