Skip to content

Commit 870d581

Browse files
authored
[EIS] Implement chunked & batched inference for sparse text embeddings (#129922)
1 parent 9c52106 commit 870d581

File tree

5 files changed

+71
-20
lines changed

5 files changed

+71
-20
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.Nullable;
1818
import org.elasticsearch.core.TimeValue;
1919
import org.elasticsearch.inference.ChunkedInference;
20+
import org.elasticsearch.inference.ChunkingSettings;
2021
import org.elasticsearch.inference.EmptySecretSettings;
2122
import org.elasticsearch.inference.EmptyTaskSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,6 +37,8 @@
3637
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
3738
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3839
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
40+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
41+
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3942
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
4043
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
4144
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -71,6 +74,7 @@
7174
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
7275
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
7376
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
77+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
7478
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
7579
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
7680
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -80,6 +84,7 @@ public class ElasticInferenceService extends SenderService {
8084

8185
public static final String NAME = "elastic";
8286
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
87+
public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;
8388

8489
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
8590
TaskType.SPARSE_EMBEDDING,
@@ -161,7 +166,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
161166
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
162167
EmptyTaskSettings.INSTANCE,
163168
EmptySecretSettings.INSTANCE,
164-
elasticInferenceServiceComponents
169+
elasticInferenceServiceComponents,
170+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
165171
),
166172
MinimalServiceSettings.sparseEmbedding(NAME)
167173
),
@@ -304,12 +310,25 @@ protected void doChunkedInfer(
304310
TimeValue timeout,
305311
ActionListener<List<ChunkedInference>> listener
306312
) {
307-
// Pass-through without actually performing chunking (result will have a single chunk per input)
308-
ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
309-
(delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
310-
);
313+
if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
314+
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
315+
316+
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
317+
inputs.getInputs(),
318+
SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE,
319+
model.getConfigurations().getChunkingSettings()
320+
).batchRequestsWithListeners(listener);
321+
322+
for (var request : batchedRequests) {
323+
var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings);
324+
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
325+
}
326+
327+
return;
328+
}
311329

312-
doInfer(model, inputs, taskSettings, timeout, inferListener);
330+
// Model cannot perform chunked inference
331+
listener.onFailure(createInvalidModelException(model));
313332
}
314333

315334
@Override
@@ -328,6 +347,13 @@ public void parseRequestConfig(
328347
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
329348
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
330349

350+
ChunkingSettings chunkingSettings = null;
351+
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
352+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
353+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
354+
);
355+
}
356+
331357
ElasticInferenceServiceModel model = createModel(
332358
inferenceEntityId,
333359
taskType,
@@ -336,7 +362,8 @@ public void parseRequestConfig(
336362
serviceSettingsMap,
337363
elasticInferenceServiceComponents,
338364
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
339-
ConfigurationParseContext.REQUEST
365+
ConfigurationParseContext.REQUEST,
366+
chunkingSettings
340367
);
341368

342369
throwIfNotEmptyMap(config, NAME);
@@ -372,7 +399,8 @@ private static ElasticInferenceServiceModel createModel(
372399
@Nullable Map<String, Object> secretSettings,
373400
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
374401
String failureMessage,
375-
ConfigurationParseContext context
402+
ConfigurationParseContext context,
403+
ChunkingSettings chunkingSettings
376404
) {
377405
return switch (taskType) {
378406
case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
@@ -383,7 +411,8 @@ private static ElasticInferenceServiceModel createModel(
383411
taskSettings,
384412
secretSettings,
385413
elasticInferenceServiceComponents,
386-
context
414+
context,
415+
chunkingSettings
387416
);
388417
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
389418
inferenceEntityId,
@@ -420,13 +449,19 @@ public Model parsePersistedConfigWithSecrets(
420449
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
421450
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
422451

452+
ChunkingSettings chunkingSettings = null;
453+
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
454+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
455+
}
456+
423457
return createModelFromPersistent(
424458
inferenceEntityId,
425459
taskType,
426460
serviceSettingsMap,
427461
taskSettingsMap,
428462
secretSettingsMap,
429-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
463+
parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
464+
chunkingSettings
430465
);
431466
}
432467

@@ -435,13 +470,19 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
435470
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
436471
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
437472

473+
ChunkingSettings chunkingSettings = null;
474+
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
475+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
476+
}
477+
438478
return createModelFromPersistent(
439479
inferenceEntityId,
440480
taskType,
441481
serviceSettingsMap,
442482
taskSettingsMap,
443483
null,
444-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
484+
parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
485+
chunkingSettings
445486
);
446487
}
447488

@@ -456,7 +497,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
456497
Map<String, Object> serviceSettings,
457498
Map<String, Object> taskSettings,
458499
@Nullable Map<String, Object> secretSettings,
459-
String failureMessage
500+
String failureMessage,
501+
ChunkingSettings chunkingSettings
460502
) {
461503
return createModel(
462504
inferenceEntityId,
@@ -466,7 +508,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
466508
secretSettings,
467509
elasticInferenceServiceComponents,
468510
failureMessage,
469-
ConfigurationParseContext.PERSISTENT
511+
ConfigurationParseContext.PERSISTENT,
512+
chunkingSettings
470513
);
471514
}
472515

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.ChunkingSettings;
1213
import org.elasticsearch.inference.EmptySecretSettings;
1314
import org.elasticsearch.inference.EmptyTaskSettings;
1415
import org.elasticsearch.inference.ModelConfigurations;
@@ -39,7 +40,8 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
3940
Map<String, Object> taskSettings,
4041
Map<String, Object> secrets,
4142
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
42-
ConfigurationParseContext context
43+
ConfigurationParseContext context,
44+
ChunkingSettings chunkingSettings
4345
) {
4446
this(
4547
inferenceEntityId,
@@ -48,7 +50,8 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
4850
ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context),
4951
EmptyTaskSettings.INSTANCE,
5052
EmptySecretSettings.INSTANCE,
51-
elasticInferenceServiceComponents
53+
elasticInferenceServiceComponents,
54+
chunkingSettings
5255
);
5356
}
5457

@@ -67,10 +70,11 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
6770
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings,
6871
@Nullable TaskSettings taskSettings,
6972
@Nullable SecretSettings secretSettings,
70-
ElasticInferenceServiceComponents elasticInferenceServiceComponents
73+
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
74+
ChunkingSettings chunkingSettings
7175
) {
7276
super(
73-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
77+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
7478
new ModelSecrets(secretSettings),
7579
serviceSettings,
7680
elasticInferenceServiceComponents

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.inference.EmptyTaskSettings;
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
1415
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
1516
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
1617

@@ -28,7 +29,8 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2829
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2930
EmptyTaskSettings.INSTANCE,
3031
EmptySecretSettings.INSTANCE,
31-
ElasticInferenceServiceComponents.of(url)
32+
ElasticInferenceServiceComponents.of(url),
33+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
3234
);
3335
}
3436
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws I
835835
}
836836
}
837837

838-
public void testChunkedInfer_PassesThrough() throws IOException {
838+
public void testChunkedInfer() throws IOException {
839839
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
840840
var elasticInferenceServiceURL = getUrl(webServer);
841841

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.test.ESSingleNodeTestCase;
2222
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
2323
import org.elasticsearch.xpack.inference.Utils;
24+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
2425
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2526
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2627
import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
@@ -196,7 +197,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints() {
196197
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-v2", null, null),
197198
EmptyTaskSettings.INSTANCE,
198199
EmptySecretSettings.INSTANCE,
199-
ElasticInferenceServiceComponents.EMPTY_INSTANCE
200+
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
201+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
200202
),
201203
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME)
202204
)

0 commit comments

Comments
 (0)