@@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
     private final EmbeddingType embeddingType;
     private final ChunkingSettings chunkingSettings;
 
-    private List<List<String>> chunkedInputs;
+    private List<ChunkOffsetsAndInput> chunkedOffsets;
     private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
     private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
     private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
@@ -109,7 +109,7 @@ public EmbeddingRequestChunker(
     }
 
     private void splitIntoBatchedRequests(List<String> inputs) {
-        Function<String, List<String>> chunkFunction;
+        Function<String, List<Chunker.ChunkOffset>> chunkFunction;
         if (chunkingSettings != null) {
             var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
             chunkFunction = input -> chunker.chunk(input, chunkingSettings);
@@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
             chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap);
         }
 
-        chunkedInputs = new ArrayList<>(inputs.size());
+        chunkedOffsets = new ArrayList<>(inputs.size());
         switch (embeddingType) {
             case FLOAT -> floatResults = new ArrayList<>(inputs.size());
             case BYTE -> byteResults = new ArrayList<>(inputs.size());
@@ -128,18 +128,19 @@ private void splitIntoBatchedRequests(List<String> inputs) {
 
         for (int i = 0; i < inputs.size(); i++) {
             var chunks = chunkFunction.apply(inputs.get(i));
-            int numberOfSubBatches = addToBatches(chunks, i);
+            var offSetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i));
+            int numberOfSubBatches = addToBatches(offSetsAndInput, i);
             // size the results array with the expected number of request/responses
             switch (embeddingType) {
                 case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
                 case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
                 case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
             }
-            chunkedInputs.add(chunks);
+            chunkedOffsets.add(offSetsAndInput);
         }
     }
 
-    private int addToBatches(List<String> chunks, int inputIndex) {
+    private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) {
         BatchRequest lastBatch;
         if (batchedRequests.isEmpty()) {
             lastBatch = new BatchRequest(new ArrayList<>());
@@ -157,16 +158,24 @@ private int addToBatches(List<String> chunks, int inputIndex) {
 
         if (freeSpace > 0) {
             // use any free space in the previous batch before creating new batches
-            int toAdd = Math.min(freeSpace, chunks.size());
-            lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
+            int toAdd = Math.min(freeSpace, chunk.offsets().size());
+            lastBatch.addSubBatch(
+                new SubBatch(
+                    new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()),
+                    new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
+                )
+            );
         }
 
         int start = freeSpace;
-        while (start < chunks.size()) {
-            int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start);
+        while (start < chunk.offsets().size()) {
+            int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start);
             var batch = new BatchRequest(new ArrayList<>());
             batch.addSubBatch(
-                new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))
+                new SubBatch(
+                    new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()),
+                    new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
+                )
             );
             batchedRequests.add(batch);
             start += toAdd;
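
For context, a minimal editorial sketch of the sub-batching arithmetic in the hunk above: an input's chunk offsets first fill any free space left in the last open batch, then spill into fresh batches of at most `maxNumberOfInputsPerBatch` entries. All names and values here are illustrative, not taken from the PR.

```java
import java.util.ArrayList;
import java.util.List;

public class BatchSplitDemo {
    public static void main(String[] args) {
        int maxNumberOfInputsPerBatch = 3; // assumed capacity per batch
        int freeSpace = 2;                 // room remaining in the previous batch
        List<Integer> offsets = List.of(0, 1, 2, 3, 4, 5, 6); // 7 chunk offsets for one input

        List<List<Integer>> subBatches = new ArrayList<>();
        // use any free space in the previous batch before creating new batches
        int firstTake = Math.min(freeSpace, offsets.size());
        if (firstTake > 0) {
            subBatches.add(offsets.subList(0, firstTake));
        }
        // then split the remainder into batches of at most maxNumberOfInputsPerBatch
        int start = firstTake;
        while (start < offsets.size()) {
            int toAdd = Math.min(maxNumberOfInputsPerBatch, offsets.size() - start);
            subBatches.add(offsets.subList(start, start + toAdd));
            start += toAdd;
        }
        System.out.println(subBatches); // [[0, 1], [2, 3, 4], [5, 6]]
    }
}
```
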
@@ -333,8 +342,8 @@ public void onFailure(Exception e) {
         }
 
         private void sendResponse() {
-            var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedInputs.size());
-            for (int i = 0; i < chunkedInputs.size(); i++) {
+            var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedOffsets.size());
+            for (int i = 0; i < chunkedOffsets.size(); i++) {
                 if (errors.get(i) != null) {
                     response.add(errors.get(i));
                 } else {
@@ -348,9 +357,9 @@ private void sendResponse() {
 
     private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) {
         return switch (embeddingType) {
-            case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex));
-            case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex));
-            case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex));
+            case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex));
+            case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex));
+            case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex));
         };
     }
 
@@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) {
         }
 
         public List<String> inputs() {
-            return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList());
+            return subBatches.stream().flatMap(s -> s.requests().toChunkText().stream()).collect(Collectors.toList());
         }
     }
 
@@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
      */
     record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {}
 
-    record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
-        public int size() {
-            return requests.size();
+    record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) {
+        int size() {
+            return requests.offsets().size();
+        }
+    }
+
+    record ChunkOffsetsAndInput(List<Chunker.ChunkOffset> offsets, String input) {
+        List<String> toChunkText() {
+            return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList());
         }
     }
 }
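
Finally, a self-contained sketch of the pattern the new `ChunkOffsetsAndInput` record implements: the chunker now reports character offsets into the original input, and `toChunkText()` materializes the chunk strings from those offsets on demand. The `ChunkOffset` record below is a stand-in for `Chunker.ChunkOffset`, assumed here to be a simple start/end character range rather than the actual Elasticsearch class.

```java
import java.util.List;
import java.util.stream.Collectors;

public class ChunkOffsetsDemo {
    // Stand-in for Chunker.ChunkOffset: a [start, end) range into the input string.
    record ChunkOffset(int start, int end) {}

    // Mirrors the record added in the diff: chunk offsets paired with the original input.
    record ChunkOffsetsAndInput(List<ChunkOffset> offsets, String input) {
        List<String> toChunkText() {
            return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList());
        }
    }

    public static void main(String[] args) {
        var chunked = new ChunkOffsetsAndInput(
            List.of(new ChunkOffset(0, 11), new ChunkOffset(12, 24)),
            "first chunk second chunk"
        );
        System.out.println(chunked.toChunkText()); // [first chunk, second chunk]
    }
}
```

The apparent benefit of storing offsets plus one reference to the input, rather than copied substrings, is that overlapping chunks do not duplicate text and each chunk's position in the original input stays available to callers.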