Skip to content

Commit 285bb13

Browse files
Fix truncation in embeddings_ov calculator (#3581)
1 parent 5bcd8a6 commit 285bb13

File tree

10 files changed

+34
-7
lines changed

10 files changed

+34
-7
lines changed

demos/common/export_models/export_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def add_common_arguments(parser):
142142
normalize_embeddings: {% if not normalize %}false{% else %}true{% endif%},
143143
{%- if pooling %}
144144
pooling: {{pooling}},{% endif %}
145+
{%- if truncate %}
146+
truncate: true,{% endif %}
145147
target_device: "{{target_device|default("CPU", true)}}"
146148
}
147149
}
@@ -542,10 +544,6 @@ def export_embeddings_model_ov(model_repository_path, source_model, model_name,
542544
optimum_command = "optimum-cli export openvino --model {} --disable-convert-tokenizer --task feature-extraction --weight-format {} {} --trust-remote-code --library sentence_transformers {}".format(source_model, precision, task_parameters['extra_quantization_params'], destination_path)
543545
if os.system(optimum_command):
544546
raise ValueError("Failed to export embeddings model", source_model)
545-
if truncate:
546-
max_context_length = get_models_max_context(destination_path, 'config.json')
547-
if max_context_length is not None:
548-
set_max_context_length = "--max_length " + str(get_models_max_context(destination_path, 'config.json'))
549547
print("Exporting tokenizer to ", destination_path)
550548
convert_tokenizer_command = "convert_tokenizer -o {} {} {}".format(destination_path, source_model, set_max_context_length)
551549
if (os.system(convert_tokenizer_command)):

docs/parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ Task specific parameters for different tasks (text generation/image generation/e
148148
|---------------------------|--------------|--------------------------------------------------------------------------------|
149149
| `--num_streams` | `integer` | The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems. Default: 1. |
150150
| `--normalize` | `bool` | Normalize the embeddings. Default: true. |
151+
| `--truncate` | `bool` | Truncate input when it exceeds model context length. Default: false. |
151152
| `--mean_pooling` | `bool` | Mean pooling option. Default: false. |
152153

153154
### Rerank

src/capi_frontend/server_settings.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ struct EmbeddingsGraphSettingsImpl {
114114
std::string modelName = "";
115115
uint32_t numStreams = 1;
116116
std::string normalize = "true";
117+
std::string truncate = "false";
117118
std::string pooling = "CLS";
118119
};
119120

src/config.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ bool Config::validate() {
174174
std::cerr << "normalize: " << settings.normalize << " is not allowed. Supported values: true, false" << std::endl;
175175
return false;
176176
}
177+
178+
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.truncate) == allowedBoolValues.end()) {
179+
std::cerr << "truncate: " << settings.truncate << " is not allowed. Supported values: true, false" << std::endl;
180+
return false;
181+
}
177182
}
178183
// No more validation needed
179184
if (this->serverSettings.serverMode == HF_PULL_MODE) {

src/embeddings/embeddings_calculator_ov.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,11 @@ class EmbeddingsCalculatorOV : public CalculatorBase {
128128
auto input = handler.getInput();
129129
if (auto strings = std::get_if<std::vector<std::string>>(&input)) {
130130
received_batch_size = strings->size();
131-
tokens = embeddings_session->getTokenizer().encode(*strings);
131+
ov::AnyMap params = {};
132+
if (cc->Options<EmbeddingsCalculatorOVOptions>().truncate()) {
133+
params = {{"max_length", max_context_length}};
134+
}
135+
tokens = embeddings_session->getTokenizer().encode(*strings, params);
132136
RET_CHECK(tokens.input_ids.get_shape().size() == 2);
133137
size_t input_ids_size = tokens.input_ids.get_shape()[1];
134138
if (input_ids_size > max_context_length) {

src/embeddings/embeddings_calculator_ov.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ message EmbeddingsCalculatorOVOptions {
3434
LAST = 1;
3535
}
3636
optional Pooling pooling = 5 [default = CLS];
37+
optional bool truncate = 6 [default = false];
3738
}

src/graph_export/embeddings_graph_cli_parser.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ void EmbeddingsGraphCLIParser::createOptions() {
4848
"Normalize the embeddings.",
4949
cxxopts::value<std::string>()->default_value("true"),
5050
"NORMALIZE")
51+
("truncate",
52+
"Truncate input when it exceeds model context length.",
53+
cxxopts::value<std::string>()->default_value("false"),
54+
"truncate")
5155
("pooling",
5256
"Mean pooling option.",
5357
cxxopts::value<std::string>()->default_value("CLS"),
@@ -91,6 +95,7 @@ void EmbeddingsGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl
9195
} else {
9296
embeddingsGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
9397
embeddingsGraphSettings.normalize = result->operator[]("normalize").as<std::string>();
98+
embeddingsGraphSettings.truncate = result->operator[]("truncate").as<std::string>();
9499
embeddingsGraphSettings.pooling = result->operator[]("pooling").as<std::string>();
95100
}
96101
if (!(embeddingsGraphSettings.pooling == "CLS" || embeddingsGraphSettings.pooling == "LAST")){

src/graph_export/graph_export.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ node {
214214
<< graphOkPath << R"(",
215215
normalize_embeddings: )"
216216
<< graphSettings.normalize << R"(,
217+
truncate: )"
218+
<< graphSettings.truncate << R"(,
217219
pooling: )"
218220
<< graphSettings.pooling << R"(,
219221
target_device: ")" << graphSettings.targetDevice << R"(",

src/test/graph_export_test.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ node {
234234
[type.googleapis.com / mediapipe.EmbeddingsCalculatorOVOptions]: {
235235
models_path: "/model1/path",
236236
normalize_embeddings: false,
237+
truncate: true,
237238
pooling: LAST,
238239
target_device: "GPU",
239240
plugin_config: '{ "NUM_STREAMS": "2"}',
@@ -255,6 +256,7 @@ node {
255256
[type.googleapis.com / mediapipe.EmbeddingsCalculatorOVOptions]: {
256257
models_path: "./",
257258
normalize_embeddings: true,
259+
truncate: false,
258260
pooling: CLS,
259261
target_device: "CPU",
260262
plugin_config: '{ "NUM_STREAMS": "1"}',
@@ -458,6 +460,7 @@ TEST_F(GraphCreationTest, embeddingsPositiveNonDefault) {
458460
embeddingsGraphSettings.modelPath = "/model1/path";
459461
embeddingsGraphSettings.numStreams = 2;
460462
embeddingsGraphSettings.normalize = "false";
463+
embeddingsGraphSettings.truncate = "true";
461464
embeddingsGraphSettings.pooling = "LAST";
462465
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
463466
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";

src/test/ovmsconfig_test.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,13 +1686,15 @@ TEST(OvmsGraphConfigTest, positiveAllChangedEmbeddings) {
16861686
(char*)"GPU",
16871687
(char*)"--normalize",
16881688
(char*)"false",
1689+
(char*)"--truncate",
1690+
(char*)"true",
16891691
(char*)"--num_streams",
16901692
(char*)"2",
16911693
(char*)"--model_name",
16921694
(char*)servingName.c_str(),
16931695
};
16941696

1695-
int arg_count = 18;
1697+
int arg_count = 20;
16961698
ConstructorEnabledConfig config;
16971699
config.parse(arg_count, n_argv);
16981700

@@ -1703,6 +1705,7 @@ TEST(OvmsGraphConfigTest, positiveAllChangedEmbeddings) {
17031705
ASSERT_EQ(hfSettings.task, ovms::EMBEDDINGS_GRAPH);
17041706
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings = std::get<ovms::EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings);
17051707
ASSERT_EQ(embeddingsGraphSettings.normalize, "false");
1708+
ASSERT_EQ(embeddingsGraphSettings.truncate, "true");
17061709
ASSERT_EQ(embeddingsGraphSettings.pooling, "CLS");
17071710
ASSERT_EQ(embeddingsGraphSettings.numStreams, 2);
17081711
ASSERT_EQ(embeddingsGraphSettings.targetDevice, "GPU");
@@ -1728,6 +1731,8 @@ TEST(OvmsGraphConfigTest, positiveAllChangedEmbeddingsStart) {
17281731
(char*)"GPU",
17291732
(char*)"--normalize",
17301733
(char*)"false",
1734+
(char*)"--truncate",
1735+
(char*)"true",
17311736
(char*)"--num_streams",
17321737
(char*)"2",
17331738
(char*)"--model_name",
@@ -1736,7 +1741,7 @@ TEST(OvmsGraphConfigTest, positiveAllChangedEmbeddingsStart) {
17361741
(char*)"8080",
17371742
};
17381743

1739-
int arg_count = 19;
1744+
int arg_count = 21;
17401745
ConstructorEnabledConfig config;
17411746
config.parse(arg_count, n_argv);
17421747

@@ -1747,6 +1752,7 @@ TEST(OvmsGraphConfigTest, positiveAllChangedEmbeddingsStart) {
17471752
ASSERT_EQ(hfSettings.task, ovms::EMBEDDINGS_GRAPH);
17481753
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings = std::get<ovms::EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings);
17491754
ASSERT_EQ(embeddingsGraphSettings.normalize, "false");
1755+
ASSERT_EQ(embeddingsGraphSettings.truncate, "true");
17501756
ASSERT_EQ(embeddingsGraphSettings.pooling, "LAST");
17511757
ASSERT_EQ(embeddingsGraphSettings.numStreams, 2);
17521758
ASSERT_EQ(embeddingsGraphSettings.targetDevice, "GPU");
@@ -1779,6 +1785,7 @@ TEST(OvmsGraphConfigTest, positiveDefaultEmbeddings) {
17791785
ASSERT_EQ(hfSettings.task, ovms::EMBEDDINGS_GRAPH);
17801786
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings = std::get<ovms::EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings);
17811787
ASSERT_EQ(embeddingsGraphSettings.normalize, "true");
1788+
ASSERT_EQ(embeddingsGraphSettings.truncate, "false");
17821789
ASSERT_EQ(embeddingsGraphSettings.pooling, "CLS");
17831790
ASSERT_EQ(embeddingsGraphSettings.numStreams, 1);
17841791
ASSERT_EQ(embeddingsGraphSettings.targetDevice, "CPU");

0 commit comments

Comments
 (0)