@@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
6868 private final EmbeddingType embeddingType ;
6969 private final ChunkingSettings chunkingSettings ;
7070
71- private List <List < String >> chunkedInputs ;
71+ private List <ChunkOffsetsAndInput > chunkedOffsets ;
7272 private List <AtomicArray <List <InferenceTextEmbeddingFloatResults .InferenceFloatEmbedding >>> floatResults ;
7373 private List <AtomicArray <List <InferenceTextEmbeddingByteResults .InferenceByteEmbedding >>> byteResults ;
7474 private List <AtomicArray <List <SparseEmbeddingResults .Embedding >>> sparseResults ;
@@ -109,7 +109,7 @@ public EmbeddingRequestChunker(
109109 }
110110
111111 private void splitIntoBatchedRequests (List <String > inputs ) {
112- Function <String , List <String >> chunkFunction ;
112+ Function <String , List <Chunker . ChunkOffset >> chunkFunction ;
113113 if (chunkingSettings != null ) {
114114 var chunker = ChunkerBuilder .fromChunkingStrategy (chunkingSettings .getChunkingStrategy ());
115115 chunkFunction = input -> chunker .chunk (input , chunkingSettings );
@@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
118118 chunkFunction = input -> chunker .chunk (input , wordsPerChunk , chunkOverlap );
119119 }
120120
121- chunkedInputs = new ArrayList <>(inputs .size ());
121+ chunkedOffsets = new ArrayList <>(inputs .size ());
122122 switch (embeddingType ) {
123123 case FLOAT -> floatResults = new ArrayList <>(inputs .size ());
124124 case BYTE -> byteResults = new ArrayList <>(inputs .size ());
@@ -128,18 +128,19 @@ private void splitIntoBatchedRequests(List<String> inputs) {
128128
129129 for (int i = 0 ; i < inputs .size (); i ++) {
130130 var chunks = chunkFunction .apply (inputs .get (i ));
131- int numberOfSubBatches = addToBatches (chunks , i );
131+ var offSetsAndInput = new ChunkOffsetsAndInput (chunks , inputs .get (i ));
132+ int numberOfSubBatches = addToBatches (offSetsAndInput , i );
132133 // size the results array with the expected number of request/responses
133134 switch (embeddingType ) {
134135 case FLOAT -> floatResults .add (new AtomicArray <>(numberOfSubBatches ));
135136 case BYTE -> byteResults .add (new AtomicArray <>(numberOfSubBatches ));
136137 case SPARSE -> sparseResults .add (new AtomicArray <>(numberOfSubBatches ));
137138 }
138- chunkedInputs .add (chunks );
139+ chunkedOffsets .add (offSetsAndInput );
139140 }
140141 }
141142
142- private int addToBatches (List < String > chunks , int inputIndex ) {
143+ private int addToBatches (ChunkOffsetsAndInput chunk , int inputIndex ) {
143144 BatchRequest lastBatch ;
144145 if (batchedRequests .isEmpty ()) {
145146 lastBatch = new BatchRequest (new ArrayList <>());
@@ -157,16 +158,24 @@ private int addToBatches(List<String> chunks, int inputIndex) {
157158
158159 if (freeSpace > 0 ) {
159160 // use any free space in the previous batch before creating new batches
160- int toAdd = Math .min (freeSpace , chunks .size ());
161- lastBatch .addSubBatch (new SubBatch (chunks .subList (0 , toAdd ), new SubBatchPositionsAndCount (inputIndex , chunkIndex ++, toAdd )));
161+ int toAdd = Math .min (freeSpace , chunk .offsets ().size ());
162+ lastBatch .addSubBatch (
163+ new SubBatch (
164+ new ChunkOffsetsAndInput (chunk .offsets ().subList (0 , toAdd ), chunk .input ()),
165+ new SubBatchPositionsAndCount (inputIndex , chunkIndex ++, toAdd )
166+ )
167+ );
162168 }
163169
164170 int start = freeSpace ;
165- while (start < chunks .size ()) {
166- int toAdd = Math .min (maxNumberOfInputsPerBatch , chunks .size () - start );
171+ while (start < chunk . offsets () .size ()) {
172+ int toAdd = Math .min (maxNumberOfInputsPerBatch , chunk . offsets () .size () - start );
167173 var batch = new BatchRequest (new ArrayList <>());
168174 batch .addSubBatch (
169- new SubBatch (chunks .subList (start , start + toAdd ), new SubBatchPositionsAndCount (inputIndex , chunkIndex ++, toAdd ))
175+ new SubBatch (
176+ new ChunkOffsetsAndInput (chunk .offsets ().subList (start , start + toAdd ), chunk .input ()),
177+ new SubBatchPositionsAndCount (inputIndex , chunkIndex ++, toAdd )
178+ )
170179 );
171180 batchedRequests .add (batch );
172181 start += toAdd ;
@@ -333,8 +342,8 @@ public void onFailure(Exception e) {
333342 }
334343
335344 private void sendResponse () {
336- var response = new ArrayList <ChunkedInferenceServiceResults >(chunkedInputs .size ());
337- for (int i = 0 ; i < chunkedInputs .size (); i ++) {
345+ var response = new ArrayList <ChunkedInferenceServiceResults >(chunkedOffsets .size ());
346+ for (int i = 0 ; i < chunkedOffsets .size (); i ++) {
338347 if (errors .get (i ) != null ) {
339348 response .add (errors .get (i ));
340349 } else {
@@ -348,9 +357,9 @@ private void sendResponse() {
348357
349358 private ChunkedInferenceServiceResults mergeResultsWithInputs (int resultIndex ) {
350359 return switch (embeddingType ) {
351- case FLOAT -> mergeFloatResultsWithInputs (chunkedInputs .get (resultIndex ), floatResults .get (resultIndex ));
352- case BYTE -> mergeByteResultsWithInputs (chunkedInputs .get (resultIndex ), byteResults .get (resultIndex ));
353- case SPARSE -> mergeSparseResultsWithInputs (chunkedInputs .get (resultIndex ), sparseResults .get (resultIndex ));
360+ case FLOAT -> mergeFloatResultsWithInputs (chunkedOffsets .get (resultIndex ). toChunkText ( ), floatResults .get (resultIndex ));
361+ case BYTE -> mergeByteResultsWithInputs (chunkedOffsets .get (resultIndex ). toChunkText ( ), byteResults .get (resultIndex ));
362+ case SPARSE -> mergeSparseResultsWithInputs (chunkedOffsets .get (resultIndex ). toChunkText ( ), sparseResults .get (resultIndex ));
354363 };
355364 }
356365
@@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) {
428437 }
429438
430439 public List <String > inputs () {
431- return subBatches .stream ().flatMap (s -> s .requests ().stream ()).collect (Collectors .toList ());
440+ return subBatches .stream ().flatMap (s -> s .requests ().toChunkText (). stream ()).collect (Collectors .toList ());
432441 }
433442 }
434443
@@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
441450 */
/**
 * Positional bookkeeping for one sub-batch: the original input it came from
 * ({@code inputIndex}), this sub-batch's ordinal among that input's sub-batches
 * ({@code chunkIndex}), and the number of embeddings expected back
 * ({@code embeddingCount}).
 */
record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {}
443452
444- record SubBatch (List <String > requests , SubBatchPositionsAndCount positions ) {
445- public int size () {
446- return requests .size ();
453+ record SubBatch (ChunkOffsetsAndInput requests , SubBatchPositionsAndCount positions ) {
454+ int size () {
455+ return requests .offsets ().size ();
456+ }
457+ }
458+
459+ record ChunkOffsetsAndInput (List <Chunker .ChunkOffset > offsets , String input ) {
460+ List <String > toChunkText () {
461+ return offsets .stream ().map (o -> input .substring (o .start (), o .end ())).collect (Collectors .toList ());
447462 }
448463 }
449464}
0 commit comments