Skip to content

Commit a217498

Browse files
Add chunking settings configuration to CohereService, AmazonBedrockService, and AzureOpenAiService (#113897) (#114318)
* Add chunking settings configuration to CohereService, AmazonBedrockService, and AzureOpenAiService * Update docs/changelog/113897.yaml * Run spotlessApply * Updating CohereServiceMixedIT to account for clusters without chunking settings in index mapping --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 01b0a58 commit a217498

File tree

14 files changed

+1215
-40
lines changed

14 files changed

+1215
-40
lines changed

docs/changelog/113897.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 113897
2+
summary: "Add chunking settings configuration to `CohereService`, `AmazonBedrockService`,\
3+
\ and `AzureOpenAiService`"
4+
area: Machine Learning
5+
type: enhancement
6+
issues: []

x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Map;
2323

2424
import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
25+
import static org.hamcrest.Matchers.containsString;
2526
import static org.hamcrest.Matchers.empty;
2627
import static org.hamcrest.Matchers.hasEntry;
2728
import static org.hamcrest.Matchers.hasSize;
@@ -32,6 +33,7 @@ public class CohereServiceMixedIT extends BaseMixedTestCase {
3233

3334
private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
3435
private static final String COHERE_RERANK_ADDED = "8.14.0";
36+
private static final String COHERE_EMBEDDINGS_CHUNKING_SETTINGS_ADDED = "8.16.0";
3537
private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
3638
private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";
3739

@@ -65,13 +67,28 @@ public void testCohereEmbeddings() throws IOException {
6567
final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8";
6668
final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float";
6769

68-
// queue a response as PUT will call the service
69-
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
70-
put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
71-
72-
// float model
73-
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
74-
put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
70+
try {
71+
// queue a response as PUT will call the service
72+
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
73+
put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
74+
75+
// float model
76+
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
77+
put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
78+
} catch (Exception e) {
79+
if (bwcVersion.before(Version.fromString(COHERE_EMBEDDINGS_CHUNKING_SETTINGS_ADDED))) {
80+
// Chunking settings were added in 8.16.0. If the version is before that, an exception will be thrown if the index mapping
81+
// was created based on a mapping from an old node
82+
assertThat(
83+
e.getMessage(),
84+
containsString(
85+
"One or more nodes in your cluster does not support chunking_settings. "
86+
+ "Please update all nodes in your cluster to the latest version to use chunking_settings."
87+
)
88+
);
89+
return;
90+
}
91+
}
7592

7693
var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints");
7794
assertEquals("cohere", configs.get(0).get("service"));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
import org.elasticsearch.core.TimeValue;
1717
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
1818
import org.elasticsearch.inference.ChunkingOptions;
19+
import org.elasticsearch.inference.ChunkingSettings;
1920
import org.elasticsearch.inference.InferenceServiceResults;
2021
import org.elasticsearch.inference.InputType;
2122
import org.elasticsearch.inference.Model;
2223
import org.elasticsearch.inference.ModelConfigurations;
2324
import org.elasticsearch.inference.ModelSecrets;
2425
import org.elasticsearch.inference.TaskType;
2526
import org.elasticsearch.rest.RestStatus;
27+
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
28+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
2629
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
2730
import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionCreator;
2831
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
@@ -99,8 +102,20 @@ protected void doChunkedInfer(
99102
var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
100103
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
101104
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
102-
var batchedRequests = new EmbeddingRequestChunker(inputs.getInputs(), maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT)
103-
.batchRequestsWithListeners(listener);
105+
106+
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
107+
if (ChunkingSettingsFeatureFlag.isEnabled()) {
108+
batchedRequests = new EmbeddingRequestChunker(
109+
inputs.getInputs(),
110+
maxBatchSize,
111+
EmbeddingRequestChunker.EmbeddingType.FLOAT,
112+
baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
113+
).batchRequestsWithListeners(listener);
114+
} else {
115+
batchedRequests = new EmbeddingRequestChunker(inputs.getInputs(), maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT)
116+
.batchRequestsWithListeners(listener);
117+
}
118+
104119
for (var request : batchedRequests) {
105120
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
106121
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
@@ -126,11 +141,19 @@ public void parseRequestConfig(
126141
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
127142
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
128143

144+
ChunkingSettings chunkingSettings = null;
145+
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
146+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
147+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
148+
);
149+
}
150+
129151
AmazonBedrockModel model = createModel(
130152
modelId,
131153
taskType,
132154
serviceSettingsMap,
133155
taskSettingsMap,
156+
chunkingSettings,
134157
serviceSettingsMap,
135158
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
136159
ConfigurationParseContext.REQUEST
@@ -157,11 +180,17 @@ public Model parsePersistedConfigWithSecrets(
157180
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
158181
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
159182

183+
ChunkingSettings chunkingSettings = null;
184+
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
185+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
186+
}
187+
160188
return createModel(
161189
modelId,
162190
taskType,
163191
serviceSettingsMap,
164192
taskSettingsMap,
193+
chunkingSettings,
165194
secretSettingsMap,
166195
parsePersistedConfigErrorMsg(modelId, NAME),
167196
ConfigurationParseContext.PERSISTENT
@@ -173,11 +202,17 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
173202
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
174203
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
175204

205+
ChunkingSettings chunkingSettings = null;
206+
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
207+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
208+
}
209+
176210
return createModel(
177211
modelId,
178212
taskType,
179213
serviceSettingsMap,
180214
taskSettingsMap,
215+
chunkingSettings,
181216
null,
182217
parsePersistedConfigErrorMsg(modelId, NAME),
183218
ConfigurationParseContext.PERSISTENT
@@ -189,6 +224,7 @@ private static AmazonBedrockModel createModel(
189224
TaskType taskType,
190225
Map<String, Object> serviceSettings,
191226
Map<String, Object> taskSettings,
227+
ChunkingSettings chunkingSettings,
192228
@Nullable Map<String, Object> secretSettings,
193229
String failureMessage,
194230
ConfigurationParseContext context
@@ -201,6 +237,7 @@ private static AmazonBedrockModel createModel(
201237
NAME,
202238
serviceSettings,
203239
taskSettings,
240+
chunkingSettings,
204241
secretSettings,
205242
context
206243
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
99

1010
import org.elasticsearch.common.ValidationException;
11+
import org.elasticsearch.inference.ChunkingSettings;
1112
import org.elasticsearch.inference.EmptyTaskSettings;
1213
import org.elasticsearch.inference.Model;
1314
import org.elasticsearch.inference.ModelConfigurations;
@@ -42,6 +43,7 @@ public AmazonBedrockEmbeddingsModel(
4243
String service,
4344
Map<String, Object> serviceSettings,
4445
Map<String, Object> taskSettings,
46+
ChunkingSettings chunkingSettings,
4547
Map<String, Object> secretSettings,
4648
ConfigurationParseContext context
4749
) {
@@ -51,6 +53,7 @@ public AmazonBedrockEmbeddingsModel(
5153
service,
5254
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
5355
new EmptyTaskSettings(),
56+
chunkingSettings,
5457
AmazonBedrockSecretSettings.fromMap(secretSettings)
5558
);
5659
}
@@ -61,10 +64,11 @@ public AmazonBedrockEmbeddingsModel(
6164
String service,
6265
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
6366
TaskSettings taskSettings,
67+
ChunkingSettings chunkingSettings,
6468
AmazonBedrockSecretSettings secrets
6569
) {
6670
super(
67-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()),
71+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
6872
new ModelSecrets(secrets)
6973
);
7074
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.core.TimeValue;
1717
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
1818
import org.elasticsearch.inference.ChunkingOptions;
19+
import org.elasticsearch.inference.ChunkingSettings;
1920
import org.elasticsearch.inference.InferenceServiceResults;
2021
import org.elasticsearch.inference.InputType;
2122
import org.elasticsearch.inference.Model;
@@ -24,6 +25,8 @@
2425
import org.elasticsearch.inference.SimilarityMeasure;
2526
import org.elasticsearch.inference.TaskType;
2627
import org.elasticsearch.rest.RestStatus;
28+
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
29+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
2730
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
2831
import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
2932
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -70,11 +73,19 @@ public void parseRequestConfig(
7073
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
7174
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
7275

76+
ChunkingSettings chunkingSettings = null;
77+
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
78+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
79+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
80+
);
81+
}
82+
7383
AzureOpenAiModel model = createModel(
7484
inferenceEntityId,
7585
taskType,
7686
serviceSettingsMap,
7787
taskSettingsMap,
88+
chunkingSettings,
7889
serviceSettingsMap,
7990
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
8091
ConfigurationParseContext.REQUEST
@@ -95,6 +106,7 @@ private static AzureOpenAiModel createModelFromPersistent(
95106
TaskType taskType,
96107
Map<String, Object> serviceSettings,
97108
Map<String, Object> taskSettings,
109+
ChunkingSettings chunkingSettings,
98110
@Nullable Map<String, Object> secretSettings,
99111
String failureMessage
100112
) {
@@ -103,6 +115,7 @@ private static AzureOpenAiModel createModelFromPersistent(
103115
taskType,
104116
serviceSettings,
105117
taskSettings,
118+
chunkingSettings,
106119
secretSettings,
107120
failureMessage,
108121
ConfigurationParseContext.PERSISTENT
@@ -114,6 +127,7 @@ private static AzureOpenAiModel createModel(
114127
TaskType taskType,
115128
Map<String, Object> serviceSettings,
116129
Map<String, Object> taskSettings,
130+
ChunkingSettings chunkingSettings,
117131
@Nullable Map<String, Object> secretSettings,
118132
String failureMessage,
119133
ConfigurationParseContext context
@@ -126,6 +140,7 @@ private static AzureOpenAiModel createModel(
126140
NAME,
127141
serviceSettings,
128142
taskSettings,
143+
chunkingSettings,
129144
secretSettings,
130145
context
131146
);
@@ -156,11 +171,17 @@ public AzureOpenAiModel parsePersistedConfigWithSecrets(
156171
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
157172
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
158173

174+
ChunkingSettings chunkingSettings = null;
175+
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
176+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
177+
}
178+
159179
return createModelFromPersistent(
160180
inferenceEntityId,
161181
taskType,
162182
serviceSettingsMap,
163183
taskSettingsMap,
184+
chunkingSettings,
164185
secretSettingsMap,
165186
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
166187
);
@@ -171,11 +192,17 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType
171192
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
172193
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
173194

195+
ChunkingSettings chunkingSettings = null;
196+
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
197+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
198+
}
199+
174200
return createModelFromPersistent(
175201
inferenceEntityId,
176202
taskType,
177203
serviceSettingsMap,
178204
taskSettingsMap,
205+
chunkingSettings,
179206
null,
180207
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
181208
);
@@ -218,11 +245,23 @@ protected void doChunkedInfer(
218245
}
219246
AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
220247
var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
221-
var batchedRequests = new EmbeddingRequestChunker(
222-
inputs.getInputs(),
223-
EMBEDDING_MAX_BATCH_SIZE,
224-
EmbeddingRequestChunker.EmbeddingType.FLOAT
225-
).batchRequestsWithListeners(listener);
248+
249+
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
250+
if (ChunkingSettingsFeatureFlag.isEnabled()) {
251+
batchedRequests = new EmbeddingRequestChunker(
252+
inputs.getInputs(),
253+
EMBEDDING_MAX_BATCH_SIZE,
254+
EmbeddingRequestChunker.EmbeddingType.FLOAT,
255+
azureOpenAiModel.getConfigurations().getChunkingSettings()
256+
).batchRequestsWithListeners(listener);
257+
} else {
258+
batchedRequests = new EmbeddingRequestChunker(
259+
inputs.getInputs(),
260+
EMBEDDING_MAX_BATCH_SIZE,
261+
EmbeddingRequestChunker.EmbeddingType.FLOAT
262+
).batchRequestsWithListeners(listener);
263+
}
264+
226265
for (var request : batchedRequests) {
227266
var action = azureOpenAiModel.accept(actionCreator, taskSettings);
228267
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
99

1010
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.ChunkingSettings;
1112
import org.elasticsearch.inference.ModelConfigurations;
1213
import org.elasticsearch.inference.ModelSecrets;
1314
import org.elasticsearch.inference.TaskType;
@@ -38,6 +39,7 @@ public AzureOpenAiEmbeddingsModel(
3839
String service,
3940
Map<String, Object> serviceSettings,
4041
Map<String, Object> taskSettings,
42+
ChunkingSettings chunkingSettings,
4143
@Nullable Map<String, Object> secrets,
4244
ConfigurationParseContext context
4345
) {
@@ -47,6 +49,7 @@ public AzureOpenAiEmbeddingsModel(
4749
service,
4850
AzureOpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context),
4951
AzureOpenAiEmbeddingsTaskSettings.fromMap(taskSettings),
52+
chunkingSettings,
5053
AzureOpenAiSecretSettings.fromMap(secrets)
5154
);
5255
}
@@ -58,10 +61,11 @@ public AzureOpenAiEmbeddingsModel(
5861
String service,
5962
AzureOpenAiEmbeddingsServiceSettings serviceSettings,
6063
AzureOpenAiEmbeddingsTaskSettings taskSettings,
64+
ChunkingSettings chunkingSettings,
6165
@Nullable AzureOpenAiSecretSettings secrets
6266
) {
6367
super(
64-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
68+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
6569
new ModelSecrets(secrets),
6670
serviceSettings
6771
);

0 commit comments

Comments
 (0)