Skip to content

Commit b92e3c7

Browse files
Adding chunking settings to IbmWatsonxService (#114914) (#117278)
* Adding chunking settings to IbmWatsonxService
* Removing feature flag
* Update docs/changelog/114914.yaml

---------

Co-authored-by: Elastic Machine <[email protected]>
1 parent 6adf47e commit b92e3c7

File tree

5 files changed

+211
-3
lines changed

5 files changed

+211
-3
lines changed

docs/changelog/114914.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114914
2+
summary: Adding chunking settings to `IbmWatsonxService`
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.core.TimeValue;
1717
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
1818
import org.elasticsearch.inference.ChunkingOptions;
19+
import org.elasticsearch.inference.ChunkingSettings;
1920
import org.elasticsearch.inference.EmptySettingsConfiguration;
2021
import org.elasticsearch.inference.InferenceServiceConfiguration;
2122
import org.elasticsearch.inference.InferenceServiceResults;
@@ -30,6 +31,7 @@
3031
import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType;
3132
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
3233
import org.elasticsearch.rest.RestStatus;
34+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3335
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3436
import org.elasticsearch.xpack.inference.external.action.ibmwatsonx.IbmWatsonxActionCreator;
3537
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -86,11 +88,19 @@ public void parseRequestConfig(
8688
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
8789
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
8890

91+
ChunkingSettings chunkingSettings = null;
92+
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
93+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
94+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
95+
);
96+
}
97+
8998
IbmWatsonxModel model = createModel(
9099
inferenceEntityId,
91100
taskType,
92101
serviceSettingsMap,
93102
taskSettingsMap,
103+
chunkingSettings,
94104
serviceSettingsMap,
95105
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
96106
ConfigurationParseContext.REQUEST
@@ -112,6 +122,7 @@ private static IbmWatsonxModel createModel(
112122
TaskType taskType,
113123
Map<String, Object> serviceSettings,
114124
Map<String, Object> taskSettings,
125+
ChunkingSettings chunkingSettings,
115126
@Nullable Map<String, Object> secretSettings,
116127
String failureMessage,
117128
ConfigurationParseContext context
@@ -123,6 +134,7 @@ private static IbmWatsonxModel createModel(
123134
NAME,
124135
serviceSettings,
125136
taskSettings,
137+
chunkingSettings,
126138
secretSettings,
127139
context
128140
);
@@ -141,11 +153,17 @@ public IbmWatsonxModel parsePersistedConfigWithSecrets(
141153
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
142154
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
143155

156+
ChunkingSettings chunkingSettings = null;
157+
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
158+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
159+
}
160+
144161
return createModelFromPersistent(
145162
inferenceEntityId,
146163
taskType,
147164
serviceSettingsMap,
148165
taskSettingsMap,
166+
chunkingSettings,
149167
secretSettingsMap,
150168
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
151169
);
@@ -166,6 +184,7 @@ private static IbmWatsonxModel createModelFromPersistent(
166184
TaskType taskType,
167185
Map<String, Object> serviceSettings,
168186
Map<String, Object> taskSettings,
187+
ChunkingSettings chunkingSettings,
169188
Map<String, Object> secretSettings,
170189
String failureMessage
171190
) {
@@ -174,6 +193,7 @@ private static IbmWatsonxModel createModelFromPersistent(
174193
taskType,
175194
serviceSettings,
176195
taskSettings,
196+
chunkingSettings,
177197
secretSettings,
178198
failureMessage,
179199
ConfigurationParseContext.PERSISTENT
@@ -185,11 +205,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
185205
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
186206
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
187207

208+
ChunkingSettings chunkingSettings = null;
209+
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
210+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
211+
}
212+
188213
return createModelFromPersistent(
189214
inferenceEntityId,
190215
taskType,
191216
serviceSettingsMap,
192217
taskSettingsMap,
218+
chunkingSettings,
193219
null,
194220
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
195221
);
@@ -266,7 +292,8 @@ protected void doChunkedInfer(
266292
var batchedRequests = new EmbeddingRequestChunker(
267293
input.getInputs(),
268294
EMBEDDING_MAX_BATCH_SIZE,
269-
EmbeddingRequestChunker.EmbeddingType.FLOAT
295+
EmbeddingRequestChunker.EmbeddingType.FLOAT,
296+
model.getConfigurations().getChunkingSettings()
270297
).batchRequestsWithListeners(listener);
271298
for (var request : batchedRequests) {
272299
var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings, inputType);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsModel.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.http.client.utils.URIBuilder;
1111
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.ChunkingSettings;
1213
import org.elasticsearch.inference.EmptyTaskSettings;
1314
import org.elasticsearch.inference.InputType;
1415
import org.elasticsearch.inference.ModelConfigurations;
@@ -40,6 +41,7 @@ public IbmWatsonxEmbeddingsModel(
4041
String service,
4142
Map<String, Object> serviceSettings,
4243
Map<String, Object> taskSettings,
44+
ChunkingSettings chunkingSettings,
4345
Map<String, Object> secrets,
4446
ConfigurationParseContext context
4547
) {
@@ -49,6 +51,7 @@ public IbmWatsonxEmbeddingsModel(
4951
service,
5052
IbmWatsonxEmbeddingsServiceSettings.fromMap(serviceSettings, context),
5153
EmptyTaskSettings.INSTANCE,
54+
chunkingSettings,
5255
DefaultSecretSettings.fromMap(secrets)
5356
);
5457
}
@@ -64,10 +67,11 @@ public IbmWatsonxEmbeddingsModel(IbmWatsonxEmbeddingsModel model, IbmWatsonxEmbe
6467
String service,
6568
IbmWatsonxEmbeddingsServiceSettings serviceSettings,
6669
TaskSettings taskSettings,
70+
ChunkingSettings chunkingsettings,
6771
@Nullable DefaultSecretSettings secrets
6872
) {
6973
super(
70-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
74+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingsettings),
7175
new ModelSecrets(secrets),
7276
serviceSettings
7377
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.core.TimeValue;
2121
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
2222
import org.elasticsearch.inference.ChunkingOptions;
23+
import org.elasticsearch.inference.ChunkingSettings;
2324
import org.elasticsearch.inference.EmptyTaskSettings;
2425
import org.elasticsearch.inference.InferenceServiceConfiguration;
2526
import org.elasticsearch.inference.InferenceServiceResults;
@@ -69,6 +70,8 @@
6970
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
7071
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
7172
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
73+
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
74+
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
7275
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
7376
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
7477
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -124,6 +127,7 @@ public void testParseRequestConfig_CreatesAIbmWatsonxEmbeddingsModel() throws IO
124127
assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
125128
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
126129
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
130+
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
127131
}, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
128132

129133
service.parseRequestConfig(
@@ -150,6 +154,45 @@ public void testParseRequestConfig_CreatesAIbmWatsonxEmbeddingsModel() throws IO
150154
}
151155
}
152156

157+
public void testParseRequestConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
158+
try (var service = createIbmWatsonxService()) {
159+
ActionListener<Model> modelListener = ActionListener.wrap(model -> {
160+
assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
161+
162+
var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
163+
assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
164+
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
165+
assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
166+
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
167+
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
168+
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
169+
}, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
170+
171+
service.parseRequestConfig(
172+
"id",
173+
TaskType.TEXT_EMBEDDING,
174+
getRequestConfigMap(
175+
new HashMap<>(
176+
Map.of(
177+
ServiceFields.MODEL_ID,
178+
modelId,
179+
IbmWatsonxServiceFields.PROJECT_ID,
180+
projectId,
181+
ServiceFields.URL,
182+
url,
183+
IbmWatsonxServiceFields.API_VERSION,
184+
apiVersion
185+
)
186+
),
187+
new HashMap<>(Map.of()),
188+
createRandomChunkingSettingsMap(),
189+
getSecretSettingsMap(apiKey)
190+
),
191+
modelListener
192+
);
193+
}
194+
}
195+
153196
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
154197
try (var service = createIbmWatsonxService()) {
155198
var failureListener = getModelListenerForException(
@@ -235,6 +278,47 @@ public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsMode
235278
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
236279
assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
237280
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
281+
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
282+
}
283+
}
284+
285+
public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
286+
try (var service = createIbmWatsonxService()) {
287+
var persistedConfig = getPersistedConfigMap(
288+
new HashMap<>(
289+
Map.of(
290+
ServiceFields.MODEL_ID,
291+
modelId,
292+
IbmWatsonxServiceFields.PROJECT_ID,
293+
projectId,
294+
ServiceFields.URL,
295+
url,
296+
IbmWatsonxServiceFields.API_VERSION,
297+
apiVersion
298+
)
299+
),
300+
getTaskSettingsMapEmpty(),
301+
createRandomChunkingSettingsMap(),
302+
getSecretSettingsMap(apiKey)
303+
);
304+
305+
var model = service.parsePersistedConfigWithSecrets(
306+
"id",
307+
TaskType.TEXT_EMBEDDING,
308+
persistedConfig.config(),
309+
persistedConfig.secrets()
310+
);
311+
312+
assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
313+
314+
var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
315+
assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
316+
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
317+
assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
318+
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
319+
assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
320+
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
321+
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
238322
}
239323
}
240324

@@ -399,6 +483,73 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
399483
}
400484
}
401485

486+
public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
487+
try (var service = createIbmWatsonxService()) {
488+
var persistedConfig = getPersistedConfigMap(
489+
new HashMap<>(
490+
Map.of(
491+
ServiceFields.MODEL_ID,
492+
modelId,
493+
IbmWatsonxServiceFields.PROJECT_ID,
494+
projectId,
495+
ServiceFields.URL,
496+
url,
497+
IbmWatsonxServiceFields.API_VERSION,
498+
apiVersion
499+
)
500+
),
501+
getTaskSettingsMapEmpty(),
502+
null
503+
);
504+
505+
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
506+
507+
assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
508+
509+
var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
510+
assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
511+
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
512+
assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
513+
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
514+
assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
515+
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
516+
}
517+
}
518+
519+
public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
520+
try (var service = createIbmWatsonxService()) {
521+
var persistedConfig = getPersistedConfigMap(
522+
new HashMap<>(
523+
Map.of(
524+
ServiceFields.MODEL_ID,
525+
modelId,
526+
IbmWatsonxServiceFields.PROJECT_ID,
527+
projectId,
528+
ServiceFields.URL,
529+
url,
530+
IbmWatsonxServiceFields.API_VERSION,
531+
apiVersion
532+
)
533+
),
534+
getTaskSettingsMapEmpty(),
535+
createRandomChunkingSettingsMap(),
536+
null
537+
);
538+
539+
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
540+
541+
assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
542+
543+
var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
544+
assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
545+
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
546+
assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
547+
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
548+
assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
549+
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
550+
}
551+
}
552+
402553
public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOException {
403554
var sender = mock(Sender.class);
404555

@@ -488,7 +639,15 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException {
488639
}
489640
}
490641

491-
public void testChunkedInfer_Batches() throws IOException {
642+
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
643+
testChunkedInfer_Batches(null);
644+
}
645+
646+
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
647+
testChunkedInfer_Batches(createRandomChunkingSettings());
648+
}
649+
650+
private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException {
492651
var input = List.of("foo", "bar");
493652

494653
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@@ -878,6 +1037,18 @@ private static ActionListener<Model> getModelListenerForException(Class<?> excep
8781037
});
8791038
}
8801039

1040+
private Map<String, Object> getRequestConfigMap(
1041+
Map<String, Object> serviceSettings,
1042+
Map<String, Object> taskSettings,
1043+
Map<String, Object> chunkingSettings,
1044+
Map<String, Object> secretSettings
1045+
) {
1046+
var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings);
1047+
requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings);
1048+
1049+
return requestConfigMap;
1050+
}
1051+
8811052
private Map<String, Object> getRequestConfigMap(
8821053
Map<String, Object> serviceSettings,
8831054
Map<String, Object> taskSettings,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsModelTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public static IbmWatsonxEmbeddingsModel createModel(
8282
null
8383
),
8484
EmptyTaskSettings.INSTANCE,
85+
null,
8586
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
8687
);
8788
}

0 commit comments

Comments
 (0)