Skip to content

Commit f40947a

Browse files
Add javadoc for chunker, add tests, reduce word counting operations
1 parent 0685124 commit f40947a

File tree

2 files changed: +65 additions, −22 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616
import java.util.List;
1717
import java.util.regex.Pattern;
1818

19+
/**
20+
* Split text into chunks recursively based on a list of separator regex strings.
21+
* The maximum chunk size is measured in words and controlled
22+
* by {@code maxNumberWordsPerChunk}. For each separator the chunker will go through the following process:
23+
* 1. Split the text on each regex match of the separator.
24+
* 2. Merge consecutive chunks when it is possible to do so without exceeding the max chunk size.
25+
* 3. For each chunk after the merge:
26+
* 1. Return it if it is within the maximum chunk size.
27+
* 2. Repeat the process using the next separator in the list if the chunk exceeds the maximum chunk size.
28+
* If there are no more separators left to try, run the {@code SentenceBoundaryChunker} with the provided
29+
* max chunk size and no overlaps.
30+
*/
1931
public class RecursiveChunker implements Chunker {
2032
private final BreakIterator wordIterator;
2133

@@ -35,7 +47,7 @@ public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings)
3547
}
3648

3749
private List<ChunkOffset> chunk(String input, List<String> separators, int maxChunkSize, int separatorIndex, int chunkOffset) {
38-
if (input.length() < 2 || isChunkWithinMaxSize(input, new ChunkOffset(0, input.length()), maxChunkSize)) {
50+
if (input.length() < 2 || isChunkWithinMaxSize(buildChunkOffsetAndCount(input, 0, input.length()), maxChunkSize)) {
3951
return List.of(new ChunkOffset(chunkOffset, chunkOffset + input.length()));
4052
}
4153

@@ -44,22 +56,23 @@ private List<ChunkOffset> chunk(String input, List<String> separators, int maxCh
4456
}
4557

4658
var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize(
47-
input,
4859
splitTextBySeparatorRegex(input, separators.get(separatorIndex)),
4960
maxChunkSize
5061
);
5162
var actualChunks = new ArrayList<ChunkOffset>();
5263
for (var potentialChunk : potentialChunks) {
53-
if (isChunkWithinMaxSize(input, potentialChunk, maxChunkSize)) {
54-
actualChunks.add(new ChunkOffset(chunkOffset + potentialChunk.start(), chunkOffset + potentialChunk.end()));
64+
if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) {
65+
actualChunks.add(
66+
new ChunkOffset(chunkOffset + potentialChunk.chunkOffset.start(), chunkOffset + potentialChunk.chunkOffset.end())
67+
);
5568
} else {
5669
actualChunks.addAll(
5770
chunk(
58-
input.substring(potentialChunk.start(), potentialChunk.end()),
71+
input.substring(potentialChunk.chunkOffset.start(), potentialChunk.chunkOffset.end()),
5972
separators,
6073
maxChunkSize,
6174
separatorIndex + 1,
62-
chunkOffset + potentialChunk.start()
75+
chunkOffset + potentialChunk.chunkOffset.start()
6376
)
6477
);
6578
}
@@ -68,55 +81,63 @@ private List<ChunkOffset> chunk(String input, List<String> separators, int maxCh
6881
return actualChunks;
6982
}
7083

71-
private boolean isChunkWithinMaxSize(String fullText, ChunkOffset chunk, int maxChunkSize) {
84+
private boolean isChunkWithinMaxSize(ChunkOffsetAndCount chunkOffsetAndCount, int maxChunkSize) {
85+
return chunkOffsetAndCount.wordCount <= maxChunkSize;
86+
}
87+
88+
private ChunkOffsetAndCount buildChunkOffsetAndCount(String fullText, int chunkStart, int chunkEnd) {
89+
var chunkOffset = new ChunkOffset(chunkStart, chunkEnd);
90+
7291
wordIterator.setText(fullText);
73-
return ChunkerUtils.countWords(chunk.start(), chunk.end(), wordIterator) <= maxChunkSize;
92+
return new ChunkOffsetAndCount(chunkOffset, ChunkerUtils.countWords(chunkStart, chunkEnd, wordIterator));
7493
}
7594

76-
private List<ChunkOffset> splitTextBySeparatorRegex(String input, String separatorRegex) {
95+
private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, String separatorRegex) {
7796
var pattern = Pattern.compile(separatorRegex);
7897
var matcher = pattern.matcher(input);
7998

80-
var chunkOffsets = new ArrayList<ChunkOffset>();
99+
var chunkOffsets = new ArrayList<ChunkOffsetAndCount>();
81100
int chunkStart = 0;
82-
int searchStart = 0;
83-
while (matcher.find(searchStart)) {
101+
while (matcher.find()) {
84102
var chunkEnd = matcher.start();
85103
if (chunkStart < chunkEnd) {
86-
chunkOffsets.add(new ChunkOffset(chunkStart, chunkEnd));
104+
chunkOffsets.add(buildChunkOffsetAndCount(input, chunkStart, chunkEnd));
87105
}
88106
chunkStart = matcher.start();
89-
searchStart = matcher.end();
90107
}
91108

92109
if (chunkStart < input.length()) {
93-
chunkOffsets.add(new ChunkOffset(chunkStart, input.length()));
110+
chunkOffsets.add(buildChunkOffsetAndCount(input, chunkStart, input.length()));
94111
}
95112

96113
return chunkOffsets;
97114
}
98115

99-
private List<ChunkOffset> mergeChunkOffsetsUpToMaxChunkSize(String input, List<ChunkOffset> chunkOffsets, int maxChunkSize) {
116+
private List<ChunkOffsetAndCount> mergeChunkOffsetsUpToMaxChunkSize(List<ChunkOffsetAndCount> chunkOffsets, int maxChunkSize) {
100117
if (chunkOffsets.size() < 2) {
101118
return chunkOffsets;
102119
}
103120

104-
List<ChunkOffset> mergedOffsets = new ArrayList<>();
121+
List<ChunkOffsetAndCount> mergedOffsetsAndCounts = new ArrayList<>();
105122
var mergedChunk = chunkOffsets.getFirst();
106123
for (int i = 1; i < chunkOffsets.size(); i++) {
107-
var potentialMergedChunk = new ChunkOffset(mergedChunk.start(), chunkOffsets.get(i).end());
108-
if (isChunkWithinMaxSize(input, potentialMergedChunk, maxChunkSize)) {
124+
var chunkOffsetAndCountToMerge = chunkOffsets.get(i);
125+
var potentialMergedChunk = new ChunkOffsetAndCount(
126+
new ChunkOffset(mergedChunk.chunkOffset.start(), chunkOffsetAndCountToMerge.chunkOffset.end()),
127+
mergedChunk.wordCount + chunkOffsetAndCountToMerge.wordCount
128+
);
129+
if (isChunkWithinMaxSize(potentialMergedChunk, maxChunkSize)) {
109130
mergedChunk = potentialMergedChunk;
110131
} else {
111-
mergedOffsets.add(mergedChunk);
132+
mergedOffsetsAndCounts.add(mergedChunk);
112133
mergedChunk = chunkOffsets.get(i);
113134
}
114135

115136
if (i == chunkOffsets.size() - 1) {
116-
mergedOffsets.add(mergedChunk);
137+
mergedOffsetsAndCounts.add(mergedChunk);
117138
}
118139
}
119-
return mergedOffsets;
140+
return mergedOffsetsAndCounts;
120141
}
121142

122143
private List<ChunkOffset> chunkWithBackupChunker(String input, int maxChunkSize, int chunkOffset) {
@@ -127,4 +148,6 @@ private List<ChunkOffset> chunkWithBackupChunker(String input, int maxChunkSize,
127148
}
128149
return chunksWithOffsets;
129150
}
151+
152+
private record ChunkOffsetAndCount(ChunkOffset chunkOffset, int wordCount) {}
130153
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,26 @@ public void testChunkInputRequiresBackupChunkingStrategy() {
129129
);
130130
}
131131

132+
public void testChunkWithRegexSeparator() {
133+
var separators = List.of("(?<!\\n)\\n(?!\\n)", "(?<!\\n)\\n\\n(?!\\n)");
134+
RecursiveChunkingSettings settings = generateChunkingSettings(10, separators);
135+
String input = generateTestText(4, List.of("\n", "\n", "\n\n"));
136+
137+
var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length();
138+
var expectedSecondChunkOffsetEnd = TEST_SENTENCE.length() * 2 + "\n".length();
139+
var expectedThirdChunkOffsetEnd = TEST_SENTENCE.length() * 3 + "\n".length() * 2;
140+
assertExpectedChunksGenerated(
141+
input,
142+
settings,
143+
List.of(
144+
new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd),
145+
new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, expectedSecondChunkOffsetEnd),
146+
new Chunker.ChunkOffset(expectedSecondChunkOffsetEnd, expectedThirdChunkOffsetEnd),
147+
new Chunker.ChunkOffset(expectedThirdChunkOffsetEnd, input.length())
148+
)
149+
);
150+
}
151+
132152
private void assertExpectedChunksGenerated(String input, RecursiveChunkingSettings settings, List<Chunker.ChunkOffset> expectedChunks) {
133153
RecursiveChunker chunker = new RecursiveChunker();
134154
List<Chunker.ChunkOffset> chunks = chunker.chunk(input, settings);

0 commit comments

Comments (0)