1717import org .elasticsearch .core .Nullable ;
1818import org .elasticsearch .core .TimeValue ;
1919import org .elasticsearch .inference .ChunkedInference ;
20+ import org .elasticsearch .inference .ChunkingSettings ;
2021import org .elasticsearch .inference .EmptySecretSettings ;
2122import org .elasticsearch .inference .EmptyTaskSettings ;
2223import org .elasticsearch .inference .InferenceServiceConfiguration ;
3637import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceError ;
3738import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
3839import org .elasticsearch .xpack .core .ml .inference .results .ErrorInferenceResults ;
40+ import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsBuilder ;
41+ import org .elasticsearch .xpack .inference .chunking .EmbeddingRequestChunker ;
3942import org .elasticsearch .xpack .inference .external .action .SenderExecutableAction ;
4043import org .elasticsearch .xpack .inference .external .http .sender .EmbeddingsInput ;
4144import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
6871import static org .elasticsearch .xpack .inference .services .ServiceFields .MODEL_ID ;
6972import static org .elasticsearch .xpack .inference .services .ServiceUtils .createInvalidModelException ;
7073import static org .elasticsearch .xpack .inference .services .ServiceUtils .parsePersistedConfigErrorMsg ;
74+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMap ;
7175import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrDefaultEmpty ;
7276import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrThrowIfNull ;
7377import static org .elasticsearch .xpack .inference .services .ServiceUtils .throwIfNotEmptyMap ;
@@ -77,6 +81,7 @@ public class ElasticInferenceService extends SenderService {
7781
7882 public static final String NAME = "elastic" ;
7983 public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service" ;
84+ public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512 ;
8085
8186 private static final EnumSet <TaskType > IMPLEMENTED_TASK_TYPES = EnumSet .of (
8287 TaskType .SPARSE_EMBEDDING ,
@@ -154,7 +159,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
154159 new ElasticInferenceServiceSparseEmbeddingsServiceSettings (DEFAULT_ELSER_MODEL_ID_V2 , null , null ),
155160 EmptyTaskSettings .INSTANCE ,
156161 EmptySecretSettings .INSTANCE ,
157- elasticInferenceServiceComponents
162+ elasticInferenceServiceComponents ,
163+ ChunkingSettingsBuilder .DEFAULT_SETTINGS
158164 ),
159165 MinimalServiceSettings .sparseEmbedding (NAME )
160166 )
@@ -284,12 +290,25 @@ protected void doChunkedInfer(
284290 TimeValue timeout ,
285291 ActionListener <List <ChunkedInference >> listener
286292 ) {
287- // Pass-through without actually performing chunking (result will have a single chunk per input)
288- ActionListener <InferenceServiceResults > inferListener = listener .delegateFailureAndWrap (
289- (delegate , response ) -> delegate .onResponse (translateToChunkedResults (inputs , response ))
290- );
293+ if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel ) {
294+ var actionCreator = new ElasticInferenceServiceActionCreator (getSender (), getServiceComponents (), getCurrentTraceInfo ());
295+
296+ List <EmbeddingRequestChunker .BatchRequestAndListener > batchedRequests = new EmbeddingRequestChunker <>(
297+ inputs .getInputs (),
298+ SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE ,
299+ model .getConfigurations ().getChunkingSettings ()
300+ ).batchRequestsWithListeners (listener );
301+
302+ for (var request : batchedRequests ) {
303+ var action = sparseTextEmbeddingsModel .accept (actionCreator , taskSettings );
304+ action .execute (EmbeddingsInput .fromStrings (request .batch ().inputs ().get (), inputType ), timeout , request .listener ());
305+ }
306+
307+ return ;
308+ }
291309
292- doInfer (model , inputs , taskSettings , timeout , inferListener );
310+ // Model cannot perform chunked inference
311+ listener .onFailure (createInvalidModelException (model ));
293312 }
294313
295314 @ Override
@@ -308,6 +327,13 @@ public void parseRequestConfig(
308327 Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
309328 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
310329
330+ ChunkingSettings chunkingSettings = null ;
331+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
332+ chunkingSettings = ChunkingSettingsBuilder .fromMap (
333+ removeFromMapOrDefaultEmpty (config , ModelConfigurations .CHUNKING_SETTINGS )
334+ );
335+ }
336+
311337 ElasticInferenceServiceModel model = createModel (
312338 inferenceEntityId ,
313339 taskType ,
@@ -316,7 +342,8 @@ public void parseRequestConfig(
316342 serviceSettingsMap ,
317343 elasticInferenceServiceComponents ,
318344 TaskType .unsupportedTaskTypeErrorMsg (taskType , NAME ),
319- ConfigurationParseContext .REQUEST
345+ ConfigurationParseContext .REQUEST ,
346+ chunkingSettings
320347 );
321348
322349 throwIfNotEmptyMap (config , NAME );
@@ -352,7 +379,8 @@ private static ElasticInferenceServiceModel createModel(
352379 @ Nullable Map <String , Object > secretSettings ,
353380 ElasticInferenceServiceComponents elasticInferenceServiceComponents ,
354381 String failureMessage ,
355- ConfigurationParseContext context
382+ ConfigurationParseContext context ,
383+ ChunkingSettings chunkingSettings
356384 ) {
357385 return switch (taskType ) {
358386 case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel (
@@ -363,7 +391,8 @@ private static ElasticInferenceServiceModel createModel(
363391 taskSettings ,
364392 secretSettings ,
365393 elasticInferenceServiceComponents ,
366- context
394+ context ,
395+ chunkingSettings
367396 );
368397 case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel (
369398 inferenceEntityId ,
@@ -400,13 +429,19 @@ public Model parsePersistedConfigWithSecrets(
400429 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
401430 Map <String , Object > secretSettingsMap = removeFromMapOrDefaultEmpty (secrets , ModelSecrets .SECRET_SETTINGS );
402431
432+ ChunkingSettings chunkingSettings = null ;
433+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
434+ chunkingSettings = ChunkingSettingsBuilder .fromMap (removeFromMap (config , ModelConfigurations .CHUNKING_SETTINGS ));
435+ }
436+
403437 return createModelFromPersistent (
404438 inferenceEntityId ,
405439 taskType ,
406440 serviceSettingsMap ,
407441 taskSettingsMap ,
408442 secretSettingsMap ,
409- parsePersistedConfigErrorMsg (inferenceEntityId , NAME )
443+ parsePersistedConfigErrorMsg (inferenceEntityId , NAME ),
444+ chunkingSettings
410445 );
411446 }
412447
@@ -415,13 +450,19 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
415450 Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
416451 Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
417452
453+ ChunkingSettings chunkingSettings = null ;
454+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
455+ chunkingSettings = ChunkingSettingsBuilder .fromMap (removeFromMap (config , ModelConfigurations .CHUNKING_SETTINGS ));
456+ }
457+
418458 return createModelFromPersistent (
419459 inferenceEntityId ,
420460 taskType ,
421461 serviceSettingsMap ,
422462 taskSettingsMap ,
423463 null ,
424- parsePersistedConfigErrorMsg (inferenceEntityId , NAME )
464+ parsePersistedConfigErrorMsg (inferenceEntityId , NAME ),
465+ chunkingSettings
425466 );
426467 }
427468
@@ -436,7 +477,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
436477 Map <String , Object > serviceSettings ,
437478 Map <String , Object > taskSettings ,
438479 @ Nullable Map <String , Object > secretSettings ,
439- String failureMessage
480+ String failureMessage ,
481+ ChunkingSettings chunkingSettings
440482 ) {
441483 return createModel (
442484 inferenceEntityId ,
@@ -446,7 +488,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
446488 secretSettings ,
447489 elasticInferenceServiceComponents ,
448490 failureMessage ,
449- ConfigurationParseContext .PERSISTENT
491+ ConfigurationParseContext .PERSISTENT ,
492+ chunkingSettings
450493 );
451494 }
452495
0 commit comments