Skip to content

Commit 88a724a

Browse files
authored
[ML] Refactor the Chunker classes to return offsets (#117977) (#118279)
1 parent cb5a24d commit 88a724a

File tree

7 files changed

+119
-76
lines changed

7 files changed

+119
-76
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,7 @@
1212
import java.util.List;
1313

1414
public interface Chunker {
15-
List<String> chunk(String input, ChunkingSettings chunkingSettings);
15+
record ChunkOffset(int start, int end) {};
16+
17+
List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings);
1618
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
6868
private final EmbeddingType embeddingType;
6969
private final ChunkingSettings chunkingSettings;
7070

71-
private List<List<String>> chunkedInputs;
71+
private List<ChunkOffsetsAndInput> chunkedOffsets;
7272
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
7373
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
7474
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
@@ -109,7 +109,7 @@ public EmbeddingRequestChunker(
109109
}
110110

111111
private void splitIntoBatchedRequests(List<String> inputs) {
112-
Function<String, List<String>> chunkFunction;
112+
Function<String, List<Chunker.ChunkOffset>> chunkFunction;
113113
if (chunkingSettings != null) {
114114
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
115115
chunkFunction = input -> chunker.chunk(input, chunkingSettings);
@@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
118118
chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap);
119119
}
120120

121-
chunkedInputs = new ArrayList<>(inputs.size());
121+
chunkedOffsets = new ArrayList<>(inputs.size());
122122
switch (embeddingType) {
123123
case FLOAT -> floatResults = new ArrayList<>(inputs.size());
124124
case BYTE -> byteResults = new ArrayList<>(inputs.size());
@@ -128,18 +128,19 @@ private void splitIntoBatchedRequests(List<String> inputs) {
128128

129129
for (int i = 0; i < inputs.size(); i++) {
130130
var chunks = chunkFunction.apply(inputs.get(i));
131-
int numberOfSubBatches = addToBatches(chunks, i);
131+
var offSetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i));
132+
int numberOfSubBatches = addToBatches(offSetsAndInput, i);
132133
// size the results array with the expected number of request/responses
133134
switch (embeddingType) {
134135
case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
135136
case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
136137
case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
137138
}
138-
chunkedInputs.add(chunks);
139+
chunkedOffsets.add(offSetsAndInput);
139140
}
140141
}
141142

142-
private int addToBatches(List<String> chunks, int inputIndex) {
143+
private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) {
143144
BatchRequest lastBatch;
144145
if (batchedRequests.isEmpty()) {
145146
lastBatch = new BatchRequest(new ArrayList<>());
@@ -157,16 +158,24 @@ private int addToBatches(List<String> chunks, int inputIndex) {
157158

158159
if (freeSpace > 0) {
159160
// use any free space in the previous batch before creating new batches
160-
int toAdd = Math.min(freeSpace, chunks.size());
161-
lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
161+
int toAdd = Math.min(freeSpace, chunk.offsets().size());
162+
lastBatch.addSubBatch(
163+
new SubBatch(
164+
new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()),
165+
new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
166+
)
167+
);
162168
}
163169

164170
int start = freeSpace;
165-
while (start < chunks.size()) {
166-
int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start);
171+
while (start < chunk.offsets().size()) {
172+
int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start);
167173
var batch = new BatchRequest(new ArrayList<>());
168174
batch.addSubBatch(
169-
new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))
175+
new SubBatch(
176+
new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()),
177+
new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)
178+
)
170179
);
171180
batchedRequests.add(batch);
172181
start += toAdd;
@@ -333,8 +342,8 @@ public void onFailure(Exception e) {
333342
}
334343

335344
private void sendResponse() {
336-
var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedInputs.size());
337-
for (int i = 0; i < chunkedInputs.size(); i++) {
345+
var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedOffsets.size());
346+
for (int i = 0; i < chunkedOffsets.size(); i++) {
338347
if (errors.get(i) != null) {
339348
response.add(errors.get(i));
340349
} else {
@@ -348,9 +357,9 @@ private void sendResponse() {
348357

349358
private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) {
350359
return switch (embeddingType) {
351-
case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex));
352-
case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex));
353-
case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex));
360+
case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex));
361+
case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex));
362+
case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex));
354363
};
355364
}
356365

@@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) {
428437
}
429438

430439
public List<String> inputs() {
431-
return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList());
440+
return subBatches.stream().flatMap(s -> s.requests().toChunkText().stream()).collect(Collectors.toList());
432441
}
433442
}
434443

@@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
441450
*/
442451
record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {}
443452

444-
record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
445-
public int size() {
446-
return requests.size();
453+
record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) {
454+
int size() {
455+
return requests.offsets().size();
456+
}
457+
}
458+
459+
record ChunkOffsetsAndInput(List<Chunker.ChunkOffset> offsets, String input) {
460+
List<String> toChunkText() {
461+
return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList());
447462
}
448463
}
449464
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ public class SentenceBoundaryChunker implements Chunker {
3434
public SentenceBoundaryChunker() {
3535
sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT);
3636
wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
37-
3837
}
3938

4039
/**
@@ -45,7 +44,7 @@ public SentenceBoundaryChunker() {
4544
* @return The start/end offsets of each chunk within the input text
4645
*/
4746
@Override
48-
public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
47+
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
4948
if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) {
5049
return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0);
5150
} else {
@@ -65,8 +64,8 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
6564
* @param maxNumberWordsPerChunk Maximum size of the chunk
6665
* @return The start/end offsets of each chunk within the input text
6766
*/
68-
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
69-
var chunks = new ArrayList<String>();
67+
public List<ChunkOffset> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
68+
var chunks = new ArrayList<ChunkOffset>();
7069

7170
sentenceIterator.setText(input);
7271
wordIterator.setText(input);
@@ -91,7 +90,7 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
9190
int nextChunkWordCount = wordsInSentenceCount;
9291
if (chunkWordCount > 0) {
9392
// add a new chunk containing all the input up to this sentence
94-
chunks.add(input.substring(chunkStart, chunkEnd));
93+
chunks.add(new ChunkOffset(chunkStart, chunkEnd));
9594

9695
if (includePrecedingSentence) {
9796
if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
@@ -127,12 +126,17 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
127126
for (; i < sentenceSplits.size() - 1; i++) {
128127
// Because the substring was passed to splitLongSentence()
129128
// the returned positions need to be offset by chunkStart
130-
chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end()));
129+
chunks.add(
130+
new ChunkOffset(
131+
chunkStart + sentenceSplits.get(i).offsets().start(),
132+
chunkStart + sentenceSplits.get(i).offsets().end()
133+
)
134+
);
131135
}
132136
// The final split is partially filled.
133137
// Set the next chunk start to the beginning of the
134138
// final split of the long sentence.
135-
chunkStart = chunkStart + sentenceSplits.get(i).start(); // start pos needs to be offset by chunkStart
139+
chunkStart = chunkStart + sentenceSplits.get(i).offsets().start(); // start pos needs to be offset by chunkStart
136140
chunkWordCount = sentenceSplits.get(i).wordCount();
137141
}
138142
} else {
@@ -151,7 +155,7 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
151155
}
152156

153157
if (chunkWordCount > 0) {
154-
chunks.add(input.substring(chunkStart));
158+
chunks.add(new ChunkOffset(chunkStart, input.length()));
155159
}
156160

157161
return chunks;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import java.util.ArrayList;
1616
import java.util.List;
1717
import java.util.Locale;
18+
import java.util.stream.Collectors;
1819

1920
/**
2021
* Breaks text into smaller strings or chunks on Word boundaries.
@@ -35,7 +36,7 @@ public WordBoundaryChunker() {
3536
wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
3637
}
3738

38-
record ChunkPosition(int start, int end, int wordCount) {}
39+
record ChunkPosition(ChunkOffset offsets, int wordCount) {}
3940

4041
/**
4142
* Break the input text into small chunks as dictated
@@ -45,7 +46,7 @@ record ChunkPosition(int start, int end, int wordCount) {}
4546
* @return List of chunk offsets (start/end positions) into the input text
4647
*/
4748
@Override
48-
public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
49+
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
4950
if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) {
5051
return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap);
5152
} else {
@@ -64,18 +65,9 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
6465
* Can be 0 but must be non-negative.
6566
* @return List of chunk offsets (start/end positions) into the input text
6667
*/
67-
public List<String> chunk(String input, int chunkSize, int overlap) {
68-
69-
if (input.isEmpty()) {
70-
return List.of("");
71-
}
72-
68+
public List<ChunkOffset> chunk(String input, int chunkSize, int overlap) {
7369
var chunkPositions = chunkPositions(input, chunkSize, overlap);
74-
var chunks = new ArrayList<String>(chunkPositions.size());
75-
for (var pos : chunkPositions) {
76-
chunks.add(input.substring(pos.start, pos.end));
77-
}
78-
return chunks;
70+
return chunkPositions.stream().map(ChunkPosition::offsets).collect(Collectors.toList());
7971
}
8072

8173
/**
@@ -127,7 +119,7 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
127119
wordsSinceStartWindowWasMarked++;
128120

129121
if (wordsInChunkCountIncludingOverlap >= chunkSize) {
130-
chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap));
122+
chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, boundary), wordsInChunkCountIncludingOverlap));
131123
wordsInChunkCountIncludingOverlap = overlap;
132124

133125
if (overlap == 0) {
@@ -149,7 +141,7 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
149141
// if it ends on a boundary then the count should equal overlap in which case
150142
// we can ignore it, unless this is the first chunk in which case we want to add it
151143
if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) {
152-
chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap));
144+
chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, input.length()), wordsInChunkCountIncludingOverlap));
153145
}
154146

155147
return chunkPositions;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public void testMultipleShortInputsAreSingleBatch() {
6262
var subBatches = batches.get(0).batch().subBatches();
6363
for (int i = 0; i < inputs.size(); i++) {
6464
var subBatch = subBatches.get(i);
65-
assertThat(subBatch.requests(), contains(inputs.get(i)));
65+
assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
6666
assertEquals(0, subBatch.positions().chunkIndex());
6767
assertEquals(i, subBatch.positions().inputIndex());
6868
assertEquals(1, subBatch.positions().embeddingCount());
@@ -102,7 +102,7 @@ public void testManyInputsMakeManyBatches() {
102102
var subBatches = batches.get(0).batch().subBatches();
103103
for (int i = 0; i < batches.size(); i++) {
104104
var subBatch = subBatches.get(i);
105-
assertThat(subBatch.requests(), contains(inputs.get(i)));
105+
assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
106106
assertEquals(0, subBatch.positions().chunkIndex());
107107
assertEquals(inputIndex, subBatch.positions().inputIndex());
108108
assertEquals(1, subBatch.positions().embeddingCount());
@@ -146,7 +146,7 @@ public void testChunkingSettingsProvided() {
146146
var subBatches = batches.get(0).batch().subBatches();
147147
for (int i = 0; i < batches.size(); i++) {
148148
var subBatch = subBatches.get(i);
149-
assertThat(subBatch.requests(), contains(inputs.get(i)));
149+
assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i)));
150150
assertEquals(0, subBatch.positions().chunkIndex());
151151
assertEquals(inputIndex, subBatch.positions().inputIndex());
152152
assertEquals(1, subBatch.positions().embeddingCount());
@@ -184,17 +184,17 @@ public void testLongInputChunkedOverMultipleBatches() {
184184
assertEquals(0, subBatch.positions().inputIndex());
185185
assertEquals(0, subBatch.positions().chunkIndex());
186186
assertEquals(1, subBatch.positions().embeddingCount());
187-
assertThat(subBatch.requests(), contains("1st small"));
187+
assertThat(subBatch.requests().toChunkText(), contains("1st small"));
188188
}
189189
{
190190
var subBatch = batch.subBatches().get(1);
191191
assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
192192
assertEquals(0, subBatch.positions().chunkIndex()); // 1st part of the 2nd input
193193
assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks
194-
assertThat(subBatch.requests().get(0), startsWith("passage_input0 "));
195-
assertThat(subBatch.requests().get(1), startsWith(" passage_input20 "));
196-
assertThat(subBatch.requests().get(2), startsWith(" passage_input40 "));
197-
assertThat(subBatch.requests().get(3), startsWith(" passage_input60 "));
194+
assertThat(subBatch.requests().toChunkText().get(0), startsWith("passage_input0 "));
195+
assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input20 "));
196+
assertThat(subBatch.requests().toChunkText().get(2), startsWith(" passage_input40 "));
197+
assertThat(subBatch.requests().toChunkText().get(3), startsWith(" passage_input60 "));
198198
}
199199
}
200200
{
@@ -207,22 +207,22 @@ public void testLongInputChunkedOverMultipleBatches() {
207207
assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
208208
assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input
209209
assertEquals(2, subBatch.positions().embeddingCount());
210-
assertThat(subBatch.requests().get(0), startsWith(" passage_input80 "));
211-
assertThat(subBatch.requests().get(1), startsWith(" passage_input100 "));
210+
assertThat(subBatch.requests().toChunkText().get(0), startsWith(" passage_input80 "));
211+
assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input100 "));
212212
}
213213
{
214214
var subBatch = batch.subBatches().get(1);
215215
assertEquals(2, subBatch.positions().inputIndex()); // 3rd input
216216
assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part
217217
assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
218-
assertThat(subBatch.requests(), contains("2nd small"));
218+
assertThat(subBatch.requests().toChunkText(), contains("2nd small"));
219219
}
220220
{
221221
var subBatch = batch.subBatches().get(2);
222222
assertEquals(3, subBatch.positions().inputIndex()); // 4th input
223223
assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part
224224
assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
225-
assertThat(subBatch.requests(), contains("3rd small"));
225+
assertThat(subBatch.requests().toChunkText(), contains("3rd small"));
226226
}
227227
}
228228
}

0 commit comments

Comments
 (0)