Skip to content

Commit b41197d

Browse files
michalkulakowski, dkalinowski, dtrawins
authored
qwen3 embeddings support (#3529)
* Add support for last token pooling mode Co-authored-by: Trawinski, Dariusz <[email protected]> --------- Co-authored-by: Damian Kalinowski <[email protected]> Co-authored-by: Trawinski, Dariusz <[email protected]>
1 parent 6dffee0 commit b41197d

21 files changed

+196
-49
lines changed

demos/common/export_models/export_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def add_common_arguments(parser):
6565
parser_embeddings_ov = subparsers.add_parser('embeddings_ov', help='export model for embeddings endpoint with directory structure aligned with OpenVINO tools')
6666
add_common_arguments(parser_embeddings_ov)
6767
parser_embeddings_ov.add_argument('--skip_normalize', default=True, action='store_false', help='Skip normalize the embeddings.', dest='normalize')
68+
parser_embeddings_ov.add_argument('--pooling', default="CLS", choices=["CLS", "LAST"], help='Embeddings pooling mode', dest='pooling')
6869
parser_embeddings_ov.add_argument('--truncate', default=False, action='store_true', help='Truncate the prompts to fit to the embeddings model', dest='truncate')
6970
parser_embeddings_ov.add_argument('--num_streams', default=1,type=int, help='The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.', dest='num_streams')
7071

@@ -139,6 +140,8 @@ def add_common_arguments(parser):
139140
[type.googleapis.com / mediapipe.EmbeddingsCalculatorOVOptions]: {
140141
models_path: "{{model_path}}",
141142
normalize_embeddings: {% if not normalize %}false{% else %}true{% endif%},
143+
{%- if pooling %}
144+
pooling: {{pooling}},{% endif %}
142145
target_device: "{{target_device|default("CPU", true)}}"
143146
}
144147
}

demos/embeddings/README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,16 @@ models
5757
└── config.json
5858
5959
```
60-
> **Note** The actual models support version management and can be automatically swapped to newer version when new model is uploaded in newer version folder.
61-
> In case you trained the pytorch model it can be exported like below:
62-
> `python export_model.py embeddings_ov --source_model <pytorch model> --model_name Alibaba-NLP/gte-large-en-v1.5 --precision int8 --config_file_path models/config.json --version 2`
6360

64-
The default configuration of the `EmbeddingsCalculator` should work in most cases but the parameters can be tuned inside the `node_options` section in the `graph.pbtxt` file. Runtime configuration for both models can be tuned in `subconfig.json` file. They can be set automatically via export parameters in the `export_model.py` script.
61+
The default configuration of the `EmbeddingsCalculatorOV` should work in most cases but the parameters can be tuned inside the `node_options` section in the `graph.pbtxt` file. They can be set automatically via export parameters in the `export_model.py` script.
6562

6663
For example:
67-
`python export_model.py embeddings_ov --source_model Alibaba-NLP/gte-large-en-v1.5 --precision int8 --num_streams 2 --skip_normalize --config_file_path models/config.json`
64+
`python export_model.py embeddings_ov --source_model Alibaba-NLP/gte-large-en-v1.5 --weight-format int8 --skip_normalize --config_file_path models/config.json`
6865

66+
> **Note:** By default OVMS returns first token embeddings as sequence embeddings (called CLS pooling). It can be changed using `--pooling` option if needed by the model. Supported values are CLS and LAST. For example:
67+
```console
68+
python export_model.py embeddings_ov --source_model Qwen/Qwen3-Embedding-0.6B --weight-format fp16 --pooling LAST --config_file_path models/config.json
69+
```
6970

7071
## Tested models
7172
All models supported by [optimum-intel](https://github.com/huggingface/optimum-intel) should be compatible. In serving validation are included Hugging Face models:
@@ -240,7 +241,7 @@ The script [compare_results.py](./compare_results.py) can assist with such exper
240241
```bash
241242
popd
242243
cd model_server/demos/embeddings
243-
python compare_results.py --model Alibaba-NLP/gte-large-en-v1.5 --service_url http://localhost:8000/v3/embeddings --input "hello world" --input "goodbye world"
244+
python compare_results.py --model Alibaba-NLP/gte-large-en-v1.5 --service_url http://localhost:8000/v3/embeddings --pooling CLS --input "hello world" --input "goodbye world"
244245

245246
input ['hello world', 'goodbye world']
246247
HF Duration: 50.626 ms NewModel

demos/embeddings/compare_results.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
dest='model_name')
2929
parser.add_argument('--input', default=[], help='List of strings to query. default: []',
3030
dest='input', action='append')
31+
parser.add_argument('--pooling', default="CLS", choices=["CLS", "LAST"], help='Embeddings pooling mode', dest='pooling')
32+
3133
args = vars(parser.parse_args())
3234

3335
model_id = args['model_name']
@@ -43,8 +45,14 @@ def run_model():
4345
start_time = datetime.datetime.now()
4446
input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
4547
model_output = model_pt(**input)
46-
embeddings = model_output.last_hidden_state[:, 0]
47-
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
48+
if args['pooling'] == "LAST":
49+
sequence_lengths = input['attention_mask'].sum(dim=1) - 1
50+
batch_size = model_output.last_hidden_state.shape[0]
51+
embeddings = model_output.last_hidden_state[torch.arange(batch_size, device=model_output.last_hidden_state.device), sequence_lengths]
52+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
53+
else:
54+
embeddings = model_output.last_hidden_state[:, 0]
55+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
4856
end_time = datetime.datetime.now()
4957
duration = (end_time - start_time).total_seconds() * 1000
5058
print("HF Duration:", duration, "ms", type(model_pt).__name__)

src/capi_frontend/server_settings.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ struct EmbeddingsGraphSettingsImpl {
114114
std::string modelName = "";
115115
uint32_t numStreams = 1;
116116
std::string normalize = "true";
117-
std::string meanPooling = "false";
117+
std::string pooling = "CLS";
118118
};
119119

120120
struct RerankGraphSettingsImpl {

src/embeddings/BUILD

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ ovms_cc_library(
2929
alwayslink = 1,
3030
)
3131

32+
ovms_cc_library(
33+
name = "embeddings_servable",
34+
hdrs = ["embeddings_servable.hpp"],
35+
deps = ["//src:sidepacket_servable",],
36+
visibility = ["//visibility:public"],
37+
alwayslink = 1,
38+
)
39+
3240
mediapipe_proto_library(
3341
name = "embeddings_calculator_proto", # embeddings_calculator_cc_proto - just mediapipe stuff with mediapipe_proto_library adding nonvisible target
3442
srcs = ["embeddings_calculator.proto"],
@@ -81,11 +89,12 @@ ovms_cc_library(
8189
"//src:libovmslogging",
8290
"//src:libovmsprofiler",
8391
"embeddings_calculator_ov_cc_proto",
84-
":embeddings_api",
92+
":embeddings_servable",
8593
"//src:sidepacket_servable",
8694
"//src:model_metric_reporter",
8795
"//src:executingstreamidguard",
8896
"//src:libovms_execution_context",
97+
":embeddings_api",
8998
],
9099
visibility = ["//visibility:public"],
91100
alwayslink = 1,

src/embeddings/embeddings_api.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ void EmbeddingsHandler::setPromptTokensUsage(int promptTokens) {
162162

163163
#pragma warning(push)
164164
#pragma warning(disable : 4267)
165-
absl::Status EmbeddingsHandler::parseResponse(StringBuffer& buffer, const ov::Tensor& embeddingsTensor, const bool normalizeEmbeddings) {
165+
absl::Status EmbeddingsHandler::parseResponse(StringBuffer& buffer, const ov::Tensor& embeddingsTensor, const bool normalizeEmbeddings, const PoolingMode poolingMode, const std::optional<ov::Tensor>& attentionMask) {
166166
Writer<StringBuffer> writer(buffer);
167167
writer.StartObject();
168168

@@ -171,15 +171,42 @@ absl::Status EmbeddingsHandler::parseResponse(StringBuffer& buffer, const ov::Te
171171

172172
writer.String("data");
173173
writer.StartArray();
174-
// TODO: mean pooling
175174

176175
ov::Shape outputShape = embeddingsTensor.get_shape();
177176
if (outputShape.size() != 3) {
178177
return absl::InvalidArgumentError("Invalid embeddings tensor shape");
179178
}
180179
size_t batchSize = outputShape[0];
181180
for (size_t batchIterator = 0; batchIterator < batchSize; batchIterator++) {
182-
size_t stride = batchIterator * outputShape[1] * outputShape[2];
181+
size_t stride;
182+
if (poolingMode == PoolingMode::LAST) {
183+
size_t attendedTokens = 0;
184+
if (!attentionMask.has_value()) {
185+
return absl::InvalidArgumentError("Last token pooling mode requires attention mask");
186+
}
187+
auto maxNumberOfTokens = attentionMask->get_shape()[1];
188+
if (attentionMask->get_element_type() == ov::element::Type_t::i64) {
189+
for (int i = 0; i < maxNumberOfTokens; i++) {
190+
attendedTokens += reinterpret_cast<int64_t*>(attentionMask->data())[i + batchIterator * maxNumberOfTokens];
191+
}
192+
} else if (attentionMask->get_element_type() == ov::element::Type_t::i32) {
193+
for (int i = 0; i < maxNumberOfTokens; i++) {
194+
attendedTokens += reinterpret_cast<int32_t*>(attentionMask->data())[i + batchIterator * maxNumberOfTokens];
195+
}
196+
} else if (attentionMask->get_element_type() == ov::element::Type_t::i8) {
197+
for (int i = 0; i < maxNumberOfTokens; i++) {
198+
attendedTokens += reinterpret_cast<uint8_t*>(attentionMask->data())[i + batchIterator * maxNumberOfTokens];
199+
}
200+
} else {
201+
return absl::InternalError("Attention mask element type invalid.");
202+
}
203+
if (!(attendedTokens <= outputShape[1])) {
204+
return absl::InternalError("Embeddings output and attention mask shape mismatch");
205+
}
206+
stride = batchIterator * outputShape[1] * outputShape[2] + (attendedTokens - 1) * outputShape[2];
207+
} else {
208+
stride = batchIterator * outputShape[1] * outputShape[2];
209+
}
183210
size_t size = outputShape[2];
184211
float* dataPtr = reinterpret_cast<float*>(embeddingsTensor.data()) + stride;
185212
float* dataPtrEnd = dataPtr + size;

src/embeddings/embeddings_api.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636

3737
namespace ovms {
3838

39+
enum class PoolingMode {
40+
CLS,
41+
LAST
42+
};
43+
3944
struct EmbeddingsRequest {
4045
enum class EncodingFormat {
4146
FLOAT,
@@ -60,7 +65,7 @@ class EmbeddingsHandler {
6065
EmbeddingsRequest::EncodingFormat getEncodingFormat() const;
6166

6267
absl::Status parseRequest();
63-
absl::Status parseResponse(rapidjson::StringBuffer& buffer, const ov::Tensor& embeddingsTensor, const bool normalizeEmbeddings);
68+
absl::Status parseResponse(rapidjson::StringBuffer& buffer, const ov::Tensor& embeddingsTensor, const bool normalizeEmbeddings, const PoolingMode poolingMode = PoolingMode::CLS, const std::optional<ov::Tensor>& attentionMask = std::nullopt);
6469
void setPromptTokensUsage(int promptTokens);
6570
};
6671
} // namespace ovms

src/embeddings/embeddings_calculator.proto

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,4 @@ message EmbeddingsCalculatorOptions {
2626
optional EmbeddingsCalculatorOptions ext = 1134737;
2727
}
2828
optional bool normalize_embeddings = 1 [default = true];
29-
optional bool mean_pooling = 2 [default = false];
3029
}

src/embeddings/embeddings_calculator_ov.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
#include "../model_metric_reporter.hpp"
4141
#include "embeddings_api.hpp"
4242
#include "src/embeddings/embeddings_calculator_ov.pb.h"
43-
#include "../sidepacket_servable.hpp"
43+
#include "embeddings_servable.hpp"
4444

4545
using namespace rapidjson;
4646
using namespace ovms;
@@ -63,7 +63,7 @@ class EmbeddingsCalculatorOV : public CalculatorBase {
6363
mediapipe::Timestamp timestamp{0};
6464

6565
protected:
66-
std::shared_ptr<ovms::SidepacketServable> embeddings_session{nullptr};
66+
std::shared_ptr<ovms::EmbeddingsServable> embeddings_session{nullptr};
6767

6868
public:
6969
static absl::Status GetContract(CalculatorContract* cc) {
@@ -148,10 +148,12 @@ class EmbeddingsCalculatorOV : public CalculatorBase {
148148
for (int i = 0; i < tokens.attention_mask.get_size(); i++) {
149149
attendedTokens += reinterpret_cast<int32_t*>(tokens.attention_mask.data())[i];
150150
}
151-
} else {
151+
} else if (tokens.attention_mask.get_element_type() == ov::element::Type_t::i8) {
152152
for (int i = 0; i < tokens.attention_mask.get_byte_size(); i++) {
153153
attendedTokens += reinterpret_cast<uint8_t*>(tokens.attention_mask.data())[i];
154154
}
155+
} else {
156+
return absl::InternalError("Attention mask element type invalid.");
155157
}
156158
handler.setPromptTokensUsage(attendedTokens);
157159
} else if (auto tokenized_documents = std::get_if<std::vector<std::vector<int64_t>>>(&input)) {
@@ -241,7 +243,13 @@ class EmbeddingsCalculatorOV : public CalculatorBase {
241243

242244
auto parseResponseStartTime = std::chrono::high_resolution_clock::now();
243245
StringBuffer buffer;
244-
status = handler.parseResponse(buffer, embeddingsTensor, cc->Options<EmbeddingsCalculatorOVOptions>().normalize_embeddings());
246+
PoolingMode mode;
247+
if (cc->Options<EmbeddingsCalculatorOVOptions>().pooling() == mediapipe::EmbeddingsCalculatorOVOptions::LAST) {
248+
mode = PoolingMode::LAST;
249+
} else {
250+
mode = PoolingMode::CLS;
251+
}
252+
status = handler.parseResponse(buffer, embeddingsTensor, cc->Options<EmbeddingsCalculatorOVOptions>().normalize_embeddings(), mode, tokens.attention_mask);
245253
if (!status.ok()) {
246254
return status;
247255
}

src/embeddings/embeddings_calculator_ov.proto

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ message EmbeddingsCalculatorOVOptions {
2727
}
2828
required string models_path = 1;
2929
optional bool normalize_embeddings = 2 [default = true];
30-
optional bool mean_pooling = 3 [default = false];
31-
optional string target_device = 4 [default = "CPU"];
32-
optional string plugin_config = 5 [default = ""];
30+
optional string target_device = 3 [default = "CPU"];
31+
optional string plugin_config = 4 [default = ""];
32+
enum Pooling {
33+
CLS = 0;
34+
LAST = 1;
35+
}
36+
optional Pooling pooling = 5 [default = CLS];
3337
}

0 commit comments

Comments
 (0)