1717import org .elasticsearch .core .Nullable ;
1818import org .elasticsearch .core .TimeValue ;
1919import org .elasticsearch .inference .ChunkedInference ;
20+ import org .elasticsearch .inference .ChunkingSettings ;
2021import org .elasticsearch .inference .EmptySecretSettings ;
2122import org .elasticsearch .inference .EmptyTaskSettings ;
2223import org .elasticsearch .inference .InferenceServiceConfiguration ;
3637import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceError ;
3738import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
3839import org .elasticsearch .xpack .core .ml .inference .results .ErrorInferenceResults ;
40+ import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsBuilder ;
41+ import org .elasticsearch .xpack .inference .chunking .EmbeddingRequestChunker ;
3942import org .elasticsearch .xpack .inference .external .action .SenderExecutableAction ;
4043import org .elasticsearch .xpack .inference .external .http .sender .EmbeddingsInput ;
4144import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
7174import static org .elasticsearch .xpack .inference .services .ServiceFields .MODEL_ID ;
7275import static org .elasticsearch .xpack .inference .services .ServiceUtils .createInvalidModelException ;
7376import static org .elasticsearch .xpack .inference .services .ServiceUtils .parsePersistedConfigErrorMsg ;
77+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMap ;
7478import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrDefaultEmpty ;
7579import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrThrowIfNull ;
7680import static org .elasticsearch .xpack .inference .services .ServiceUtils .throwIfNotEmptyMap ;
@@ -80,6 +84,7 @@ public class ElasticInferenceService extends SenderService {
8084
8185 public static final String NAME = "elastic" ;
8286 public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service" ;
87+ public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512 ;
8388
8489 private static final EnumSet <TaskType > IMPLEMENTED_TASK_TYPES = EnumSet .of (
8590 TaskType .SPARSE_EMBEDDING ,
@@ -161,7 +166,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
161166 new ElasticInferenceServiceSparseEmbeddingsServiceSettings (DEFAULT_ELSER_MODEL_ID_V2 , null , null ),
162167 EmptyTaskSettings .INSTANCE ,
163168 EmptySecretSettings .INSTANCE ,
164- elasticInferenceServiceComponents
169+ elasticInferenceServiceComponents ,
170+ ChunkingSettingsBuilder .DEFAULT_SETTINGS
165171 ),
166172 MinimalServiceSettings .sparseEmbedding (NAME )
167173 ),
@@ -304,12 +310,25 @@ protected void doChunkedInfer(
304310 TimeValue timeout ,
305311 ActionListener <List <ChunkedInference >> listener
306312 ) {
307- // Pass-through without actually performing chunking (result will have a single chunk per input)
308- ActionListener <InferenceServiceResults > inferListener = listener .delegateFailureAndWrap (
309- (delegate , response ) -> delegate .onResponse (translateToChunkedResults (inputs , response ))
310- );
313+ if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel ) {
314+ var actionCreator = new ElasticInferenceServiceActionCreator (getSender (), getServiceComponents (), getCurrentTraceInfo ());
315+
316+ List <EmbeddingRequestChunker .BatchRequestAndListener > batchedRequests = new EmbeddingRequestChunker <>(
317+ inputs .getInputs (),
318+ SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE ,
319+ model .getConfigurations ().getChunkingSettings ()
320+ ).batchRequestsWithListeners (listener );
321+
322+ for (var request : batchedRequests ) {
323+ var action = sparseTextEmbeddingsModel .accept (actionCreator , taskSettings );
324+ action .execute (EmbeddingsInput .fromStrings (request .batch ().inputs ().get (), inputType ), timeout , request .listener ());
325+ }
326+
327+ return ;
328+ }
311329
312- doInfer (model , inputs , taskSettings , timeout , inferListener );
330+ // Model cannot perform chunked inference
331+ listener .onFailure (createInvalidModelException (model ));
313332 }
314333
315334 @ Override
@@ -328,6 +347,13 @@ public void parseRequestConfig(
328347 Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
329348 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
330349
350+ ChunkingSettings chunkingSettings = null ;
351+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
352+ chunkingSettings = ChunkingSettingsBuilder .fromMap (
353+ removeFromMapOrDefaultEmpty (config , ModelConfigurations .CHUNKING_SETTINGS )
354+ );
355+ }
356+
331357 ElasticInferenceServiceModel model = createModel (
332358 inferenceEntityId ,
333359 taskType ,
@@ -336,7 +362,8 @@ public void parseRequestConfig(
336362 serviceSettingsMap ,
337363 elasticInferenceServiceComponents ,
338364 TaskType .unsupportedTaskTypeErrorMsg (taskType , NAME ),
339- ConfigurationParseContext .REQUEST
365+ ConfigurationParseContext .REQUEST ,
366+ chunkingSettings
340367 );
341368
342369 throwIfNotEmptyMap (config , NAME );
@@ -372,7 +399,8 @@ private static ElasticInferenceServiceModel createModel(
372399 @ Nullable Map <String , Object > secretSettings ,
373400 ElasticInferenceServiceComponents elasticInferenceServiceComponents ,
374401 String failureMessage ,
375- ConfigurationParseContext context
402+ ConfigurationParseContext context ,
403+ ChunkingSettings chunkingSettings
376404 ) {
377405 return switch (taskType ) {
378406 case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel (
@@ -383,7 +411,8 @@ private static ElasticInferenceServiceModel createModel(
383411 taskSettings ,
384412 secretSettings ,
385413 elasticInferenceServiceComponents ,
386- context
414+ context ,
415+ chunkingSettings
387416 );
388417 case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel (
389418 inferenceEntityId ,
@@ -420,13 +449,19 @@ public Model parsePersistedConfigWithSecrets(
420449 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
421450 Map <String , Object > secretSettingsMap = removeFromMapOrDefaultEmpty (secrets , ModelSecrets .SECRET_SETTINGS );
422451
452+ ChunkingSettings chunkingSettings = null ;
453+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
454+ chunkingSettings = ChunkingSettingsBuilder .fromMap (removeFromMap (config , ModelConfigurations .CHUNKING_SETTINGS ));
455+ }
456+
423457 return createModelFromPersistent (
424458 inferenceEntityId ,
425459 taskType ,
426460 serviceSettingsMap ,
427461 taskSettingsMap ,
428462 secretSettingsMap ,
429- parsePersistedConfigErrorMsg (inferenceEntityId , NAME )
463+ parsePersistedConfigErrorMsg (inferenceEntityId , NAME ),
464+ chunkingSettings
430465 );
431466 }
432467
@@ -435,13 +470,19 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
435470 Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
436471 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
437472
473+ ChunkingSettings chunkingSettings = null ;
474+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
475+ chunkingSettings = ChunkingSettingsBuilder .fromMap (removeFromMap (config , ModelConfigurations .CHUNKING_SETTINGS ));
476+ }
477+
438478 return createModelFromPersistent (
439479 inferenceEntityId ,
440480 taskType ,
441481 serviceSettingsMap ,
442482 taskSettingsMap ,
443483 null ,
444- parsePersistedConfigErrorMsg (inferenceEntityId , NAME )
484+ parsePersistedConfigErrorMsg (inferenceEntityId , NAME ),
485+ chunkingSettings
445486 );
446487 }
447488
@@ -456,7 +497,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
456497 Map <String , Object > serviceSettings ,
457498 Map <String , Object > taskSettings ,
458499 @ Nullable Map <String , Object > secretSettings ,
459- String failureMessage
500+ String failureMessage ,
501+ ChunkingSettings chunkingSettings
460502 ) {
461503 return createModel (
462504 inferenceEntityId ,
@@ -466,7 +508,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
466508 secretSettings ,
467509 elasticInferenceServiceComponents ,
468510 failureMessage ,
469- ConfigurationParseContext .PERSISTENT
511+ ConfigurationParseContext .PERSISTENT ,
512+ chunkingSettings
470513 );
471514 }
472515
0 commit comments