2828import org .elasticsearch .inference .TaskType ;
2929import org .elasticsearch .inference .UnparsedModel ;
3030import org .elasticsearch .rest .RestStatus ;
31- import org .elasticsearch .xpack .core .inference .results .ErrorChunkedInferenceResults ;
32- import org .elasticsearch .xpack .core .inference .results .InferenceChunkedSparseEmbeddingResults ;
33- import org .elasticsearch .xpack .core .inference .results .InferenceChunkedTextEmbeddingFloatResults ;
3431import org .elasticsearch .xpack .core .inference .results .InferenceTextEmbeddingFloatResults ;
3532import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
3633import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
3734import org .elasticsearch .xpack .core .ml .action .GetTrainedModelsAction ;
3835import org .elasticsearch .xpack .core .ml .action .InferModelAction ;
3936import org .elasticsearch .xpack .core .ml .inference .results .ErrorInferenceResults ;
40- import org .elasticsearch .xpack .core .ml .inference .results .MlChunkedTextEmbeddingFloatResults ;
41- import org .elasticsearch .xpack .core .ml .inference .results .MlChunkedTextExpansionResults ;
37+ import org .elasticsearch .xpack .core .ml .inference .results .MlTextEmbeddingResults ;
38+ import org .elasticsearch .xpack .core .ml .inference .results .TextExpansionResults ;
39+ import org .elasticsearch .xpack .core .ml .inference .trainedmodel .EmptyConfigUpdate ;
4240import org .elasticsearch .xpack .core .ml .inference .trainedmodel .TextEmbeddingConfigUpdate ;
4341import org .elasticsearch .xpack .core .ml .inference .trainedmodel .TextExpansionConfigUpdate ;
4442import org .elasticsearch .xpack .core .ml .inference .trainedmodel .TextSimilarityConfigUpdate ;
45- import org .elasticsearch .xpack .core . ml . inference .trainedmodel . TokenizationConfigUpdate ;
43+ import org .elasticsearch .xpack .inference .chunking . EmbeddingRequestChunker ;
4644import org .elasticsearch .xpack .inference .services .ConfigurationParseContext ;
4745import org .elasticsearch .xpack .inference .services .ServiceUtils ;
4846
@@ -74,6 +72,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
7472 MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
7573 );
7674
75+ public static final int EMBEDDING_MAX_BATCH_SIZE = 10 ;
7776 public static final String DEFAULT_ELSER_ID = ".elser-2" ;
7877
7978 private static final Logger logger = LogManager .getLogger (ElasticsearchInternalService .class );
@@ -501,8 +500,7 @@ public void inferTextEmbedding(
501500 TextEmbeddingConfigUpdate .EMPTY_INSTANCE ,
502501 inputs ,
503502 inputType ,
504- timeout ,
505- false
503+ timeout
506504 );
507505
508506 ActionListener <InferModelAction .Response > mlResultsListener = listener .delegateFailureAndWrap (
@@ -528,8 +526,7 @@ public void inferSparseEmbedding(
528526 TextExpansionConfigUpdate .EMPTY_UPDATE ,
529527 inputs ,
530528 inputType ,
531- timeout ,
532- false
529+ timeout
533530 );
534531
535532 ActionListener <InferModelAction .Response > mlResultsListener = listener .delegateFailureAndWrap (
@@ -557,8 +554,7 @@ public void inferRerank(
557554 new TextSimilarityConfigUpdate (query ),
558555 inputs ,
559556 inputType ,
560- timeout ,
561- false
557+ timeout
562558 );
563559
564560 var modelSettings = (CustomElandRerankTaskSettings ) model .getTaskSettings ();
@@ -610,52 +606,80 @@ public void chunkedInfer(
610606
611607 if (model instanceof ElasticsearchInternalModel esModel ) {
612608
613- var configUpdate = chunkingOptions != null
614- ? new TokenizationConfigUpdate (chunkingOptions .windowSize (), chunkingOptions .span ())
615- : new TokenizationConfigUpdate (null , null );
616-
617- var request = buildInferenceRequest (
618- model .getConfigurations ().getInferenceEntityId (),
619- configUpdate ,
609+ var batchedRequests = new EmbeddingRequestChunker (
620610 input ,
621- inputType ,
622- timeout ,
623- true
624- );
611+ EMBEDDING_MAX_BATCH_SIZE ,
612+ embeddingTypeFromTaskTypeAndSettings (model .getTaskType (), esModel .internalServiceSettings )
613+ ).batchRequestsWithListeners (listener );
614+
615+ for (var batch : batchedRequests ) {
616+ var inferenceRequest = buildInferenceRequest (
617+ model .getConfigurations ().getInferenceEntityId (),
618+ EmptyConfigUpdate .INSTANCE ,
619+ batch .batch ().inputs (),
620+ inputType ,
621+ timeout
622+ );
625623
626- ActionListener <InferModelAction .Response > mlResultsListener = listener .delegateFailureAndWrap (
627- (l , inferenceResult ) -> l .onResponse (translateToChunkedResults (inferenceResult .getInferenceResults ()))
628- );
624+ ActionListener <InferModelAction .Response > mlResultsListener = batch .listener ()
625+ .delegateFailureAndWrap (
626+ (l , inferenceResult ) -> translateToChunkedResult (model .getTaskType (), inferenceResult .getInferenceResults (), l )
627+ );
629628
630- var maybeDeployListener = mlResultsListener .delegateResponse (
631- (l , exception ) -> maybeStartDeployment (esModel , exception , request , mlResultsListener )
632- );
629+ var maybeDeployListener = mlResultsListener .delegateResponse (
630+ (l , exception ) -> maybeStartDeployment (esModel , exception , inferenceRequest , mlResultsListener )
631+ );
633632
634- client .execute (InferModelAction .INSTANCE , request , maybeDeployListener );
633+ client .execute (InferModelAction .INSTANCE , inferenceRequest , maybeDeployListener );
634+ }
635635 } else {
636636 listener .onFailure (notElasticsearchModelException (model ));
637637 }
638638 }
639639
640- private static List <ChunkedInferenceServiceResults > translateToChunkedResults (List <InferenceResults > inferenceResults ) {
641- var translated = new ArrayList <ChunkedInferenceServiceResults >();
642-
643- for (var inferenceResult : inferenceResults ) {
644- translated .add (translateToChunkedResult (inferenceResult ));
645- }
646-
647- return translated ;
648- }
640+ private static void translateToChunkedResult (
641+ TaskType taskType ,
642+ List <InferenceResults > inferenceResults ,
643+ ActionListener <InferenceServiceResults > chunkPartListener
644+ ) {
645+ if (taskType == TaskType .TEXT_EMBEDDING ) {
646+ var translated = new ArrayList <InferenceTextEmbeddingFloatResults .InferenceFloatEmbedding >();
649647
650- private static ChunkedInferenceServiceResults translateToChunkedResult (InferenceResults inferenceResult ) {
651- if (inferenceResult instanceof MlChunkedTextEmbeddingFloatResults mlChunkedResult ) {
652- return InferenceChunkedTextEmbeddingFloatResults .ofMlResults (mlChunkedResult );
653- } else if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult ) {
654- return InferenceChunkedSparseEmbeddingResults .ofMlResult (mlChunkedResult );
655- } else if (inferenceResult instanceof ErrorInferenceResults error ) {
656- return new ErrorChunkedInferenceResults (error .getException ());
657- } else {
658- throw createInvalidChunkedResultException (MlChunkedTextEmbeddingFloatResults .NAME , inferenceResult .getWriteableName ());
648+ for (var inferenceResult : inferenceResults ) {
649+ if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult ) {
650+ translated .add (
651+ new InferenceTextEmbeddingFloatResults .InferenceFloatEmbedding (mlTextEmbeddingResult .getInferenceAsFloat ())
652+ );
653+ } else if (inferenceResult instanceof ErrorInferenceResults error ) {
654+ chunkPartListener .onFailure (error .getException ());
655+ return ;
656+ } else {
657+ chunkPartListener .onFailure (
658+ createInvalidChunkedResultException (MlTextEmbeddingResults .NAME , inferenceResult .getWriteableName ())
659+ );
660+ return ;
661+ }
662+ }
663+ chunkPartListener .onResponse (new InferenceTextEmbeddingFloatResults (translated ));
664+ } else { // sparse
665+ var translated = new ArrayList <SparseEmbeddingResults .Embedding >();
666+
667+ for (var inferenceResult : inferenceResults ) {
668+ if (inferenceResult instanceof TextExpansionResults textExpansionResult ) {
669+ translated .add (
670+ new SparseEmbeddingResults .Embedding (textExpansionResult .getWeightedTokens (), textExpansionResult .isTruncated ())
671+ );
672+ } else if (inferenceResult instanceof ErrorInferenceResults error ) {
673+ chunkPartListener .onFailure (error .getException ());
674+ return ;
675+ } else {
676+ chunkPartListener .onFailure (
677+ createInvalidChunkedResultException (TextExpansionResults .NAME , inferenceResult .getWriteableName ())
678+ );
679+ return ;
680+ }
681+ }
682+ chunkPartListener .onResponse (new SparseEmbeddingResults (translated ));
659683 }
660684 }
661685
@@ -738,4 +762,21 @@ public List<UnparsedModel> defaultConfigs() {
738762 protected boolean isDefaultId (String inferenceId ) {
739763 return DEFAULT_ELSER_ID .equals (inferenceId );
740764 }
765+
766+ static EmbeddingRequestChunker .EmbeddingType embeddingTypeFromTaskTypeAndSettings (
767+ TaskType taskType ,
768+ ElasticsearchInternalServiceSettings serviceSettings
769+ ) {
770+ return switch (taskType ) {
771+ case SPARSE_EMBEDDING -> EmbeddingRequestChunker .EmbeddingType .SPARSE ;
772+ case TEXT_EMBEDDING -> serviceSettings .elementType () == null
773+ ? EmbeddingRequestChunker .EmbeddingType .FLOAT
774+ : EmbeddingRequestChunker .EmbeddingType .fromDenseVectorElementType (serviceSettings .elementType ());
775+ default -> throw new ElasticsearchStatusException (
776+ "Chunking is not supported for task type [{}]" ,
777+ RestStatus .BAD_REQUEST ,
778+ taskType
779+ );
780+ };
781+ }
741782}
0 commit comments