Skip to content

Commit 5167b21

Browse files
Add recursive chunker
1 parent 6f7a206 commit 5167b21

File tree

15 files changed

+766
-91
lines changed

15 files changed

+766
-91
lines changed

server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
public enum ChunkingStrategy {
1717
WORD("word"),
18-
SENTENCE("sentence");
18+
SENTENCE("sentence"),
19+
RECURSIVE("recursive");
1920

2021
private final String chunkingStrategy;
2122

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
2727
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
2828
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
29+
import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings;
2930
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
3031
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
3132
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
@@ -469,6 +470,9 @@ private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegist
469470
SentenceBoundaryChunkingSettings::new
470471
)
471472
);
473+
namedWriteables.add(
474+
new NamedWriteableRegistry.Entry(ChunkingSettings.class, RecursiveChunkingSettings.NAME, RecursiveChunkingSettings::new)
475+
);
472476
}
473477

474478
private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) {
1818
return switch (chunkingStrategy) {
1919
case WORD -> new WordBoundaryChunker();
2020
case SENTENCE -> new SentenceBoundaryChunker();
21+
case RECURSIVE -> new RecursiveChunker();
2122
};
2223
}
2324
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

/**
 * Static helpers shared by {@code Chunker} implementations.
 */
public class ChunkerUtils {

    // setText() should be applied before using this function.
    /**
     * Counts the words between {@code start} and {@code end} using the supplied ICU word
     * {@link BreakIterator}. The caller must have already called {@code setText()} on the
     * iterator with the text that {@code start}/{@code end} index into.
     *
     * @param start        start offset (inclusive) into the iterator's text; must be less than {@code end}
     * @param end          end offset into the iterator's text; boundaries at exactly {@code end} are still counted
     * @param wordIterator an ICU word-break iterator positioned over the target text
     * @return the number of word segments whose rule status is not {@code WORD_NONE}
     *         (i.e. excluding punctuation/whitespace-only segments)
     */
    static int countWords(int start, int end, BreakIterator wordIterator) {
        assert start < end;
        wordIterator.preceding(start); // start of the current word

        int boundary = wordIterator.current();
        int wordCount = 0;
        // Walk forward one boundary at a time up to and including `end`. getRuleStatus()
        // classifies the segment delimited by the most recently returned boundary; segments
        // tagged WORD_NONE (spaces, punctuation) are skipped so only real words are counted.
        // NOTE(review): when start == 0, preceding(0) returns DONE — the count still proceeds
        // from the iterator's resulting position, but confirm this edge case is intended.
        while (boundary != BreakIterator.DONE && boundary <= end) {
            int wordStatus = wordIterator.getRuleStatus();
            if (wordStatus != BreakIterator.WORD_NONE) {
                wordCount++;
            }
            boundary = wordIterator.next();
        }

        return wordCount;
    }
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public static ChunkingSettings fromMap(Map<String, Object> settings) {
3535
return switch (chunkingStrategy) {
3636
case WORD -> WordBoundaryChunkingSettings.fromMap(settings);
3737
case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(settings);
38+
case RECURSIVE -> RecursiveChunkingSettings.fromMap(settings);
3839
};
3940
}
4041
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ public enum ChunkingSettingsOptions {
1111
STRATEGY("strategy"),
1212
MAX_CHUNK_SIZE("max_chunk_size"),
1313
OVERLAP("overlap"),
14-
SENTENCE_OVERLAP("sentence_overlap");
14+
SENTENCE_OVERLAP("sentence_overlap"),
15+
SEPARATOR_SET("separator_set"),
16+
SEPARATORS("separators");
1517

1618
private final String chunkingSettingsOption;
1719

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.ChunkingSettings;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/**
 * A {@link Chunker} that splits text recursively: it tries each configured separator regex in
 * order, splitting any piece that exceeds the maximum chunk size with the next separator, and
 * falls back to sentence-boundary chunking once all separators are exhausted. Chunk sizes are
 * measured in words via {@link ChunkerUtils#countWords}.
 *
 * <p>NOTE(review): instances hold a mutable {@code wordIterator}, so a single instance is not
 * safe for concurrent use — confirm callers use one chunker per request/thread.
 */
public class RecursiveChunker implements Chunker {
    // Reused across calls; setText() is re-applied before every word count.
    private BreakIterator wordIterator;

    public RecursiveChunker() {
        wordIterator = BreakIterator.getWordInstance();
    }

    /**
     * Chunks {@code input} according to the given settings.
     *
     * @param input            the text to chunk
     * @param chunkingSettings must be a {@link RecursiveChunkingSettings}
     * @return offsets (into {@code input}) of the produced chunks
     * @throws IllegalArgumentException if the settings are not recursive chunking settings
     */
    @Override
    public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
        if (chunkingSettings instanceof RecursiveChunkingSettings recursiveChunkingSettings) {
            return chunk(input, recursiveChunkingSettings.getSeparators(), recursiveChunkingSettings.getMaxChunkSize(), 0, 0);
        } else {
            throw new IllegalArgumentException(
                Strings.format("RecursiveChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy())
            );
        }
    }

    /**
     * Recursive worker. {@code input} is a substring of the original text and {@code chunkOffset}
     * is its start position in the original, so every returned {@link ChunkOffset} is translated
     * back into original-text coordinates.
     *
     * @param splitters     ordered list of separator regexes; {@code splitterIndex} selects the current one
     * @param maxChunkSize  maximum chunk size in words
     * @param chunkOffset   offset of {@code input} within the original text
     */
    private List<ChunkOffset> chunk(String input, List<String> splitters, int maxChunkSize, int splitterIndex, int chunkOffset) {
        // Base case: trivially small input, or the whole piece already fits in one chunk.
        if (input.length() < 2 || isChunkWithinMaxSize(input, new ChunkOffset(0, input.length()), maxChunkSize)) {
            return List.of(new ChunkOffset(chunkOffset, chunkOffset + input.length()));
        }

        // All separators exhausted but still too big: delegate to the sentence-boundary chunker.
        if (splitterIndex > splitters.size() - 1) {
            return chunkWithBackupChunker(input, maxChunkSize, chunkOffset);
        }

        var potentialChunks = splitAndMergeChunks(input, splitters.get(splitterIndex), maxChunkSize);
        var actualChunks = new ArrayList<ChunkOffset>();
        for (var potentialChunk : potentialChunks) {
            // TODO: Decide if we want to allow the first condition? Ex. "## This is a test...." split on "#" will create
            // a chunk with just "#" If the rest of the sentence is bigger than the maximum chunk size. We can either stop this by
            // doing something like splitting on the "current splitter" but skipping anything that matches the previous splitters
            // Similarly we could make the splitter a regex and update the default splitters to specifically match just the value without
            // Duplicate values around it
            // Or we can merge chunks across all levels after everything is done instead of merging them after each split
            if (potentialChunk.start() == potentialChunk.end() || isChunkWithinMaxSize(input, potentialChunk, maxChunkSize)) {
                // Empty or already-fitting piece: keep it, translated to original coordinates.
                actualChunks.add(new ChunkOffset(chunkOffset + potentialChunk.start(), chunkOffset + potentialChunk.end()));
            } else {
                // Too big: recurse on this piece with the next separator in the list.
                actualChunks.addAll(
                    chunk(
                        input.substring(potentialChunk.start(), potentialChunk.end()),
                        splitters,
                        maxChunkSize,
                        splitterIndex + 1,
                        chunkOffset + potentialChunk.start()
                    )
                );
            }
        }

        return actualChunks;
    }

    // True when the chunk's word count (per ICU word breaking) does not exceed maxChunkSize.
    private boolean isChunkWithinMaxSize(String fullText, ChunkOffset chunk, int maxChunkSize) {
        wordIterator.setText(fullText);
        return ChunkerUtils.countWords(chunk.start(), chunk.end(), wordIterator) <= maxChunkSize;
    }

    // Split on the separator, then greedily merge adjacent pieces back together up to maxChunkSize.
    private List<ChunkOffset> splitAndMergeChunks(String input, String separator, int maxChunkSize) {
        return mergeChunkOffsetsUpToMaxChunkSize(input, splitTextBySeparatorRegex(input, separator), maxChunkSize);
    }

    /**
     * Splits {@code input} on matches of {@code separatorRegex}. Each separator match starts a new
     * chunk (the separator text is kept at the head of the chunk it starts). Offsets are relative
     * to {@code input}.
     */
    private List<ChunkOffset> splitTextBySeparatorRegex(String input, String separatorRegex) {
        // NOTE(review): the pattern is recompiled on every call — consider caching if this is hot.
        var pattern = Pattern.compile(separatorRegex);
        var matcher = pattern.matcher(input);

        var chunkOffsets = new ArrayList<ChunkOffset>();
        int chunkStart = 0;
        int searchStart = 0;
        while (matcher.find(searchStart)) {
            var chunkEnd = matcher.start();
            // `<=` deliberately allows zero-length chunks (e.g. consecutive separators).
            if (chunkStart <= chunkEnd) {
                chunkOffsets.add(new ChunkOffset(chunkStart, chunkEnd));
            }
            // TODO: check what happens if it's an empty regex
            // NOTE(review): a regex that matches the empty string gives matcher.end() == matcher.start(),
            // so searchStart would not advance — verify this cannot loop forever with user-supplied separators.
            chunkStart = matcher.start();
            searchStart = matcher.end();
        }

        // Trailing text after the last separator (or the whole input if no separator matched).
        if (chunkStart < input.length()) {
            chunkOffsets.add(new ChunkOffset(chunkStart, input.length()));
        }

        return chunkOffsets;
    }

    /**
     * Greedily merges consecutive chunk offsets while the merged chunk stays within
     * {@code maxChunkSize} words; starts a new merged chunk otherwise. Offsets are relative
     * to {@code input}.
     */
    private List<ChunkOffset> mergeChunkOffsetsUpToMaxChunkSize(String input, List<ChunkOffset> chunkOffsets, int maxChunkSize) {
        // 0 or 1 chunks: nothing to merge.
        if (chunkOffsets.size() < 2) {
            return chunkOffsets;
        }

        List<ChunkOffset> mergedOffsets = new ArrayList<>();
        var mergedChunk = chunkOffsets.getFirst();
        for (int i = 1; i < chunkOffsets.size(); i++) {
            var potentialMergedChunk = new ChunkOffset(mergedChunk.start(), chunkOffsets.get(i).end());
            if (isChunkWithinMaxSize(input, potentialMergedChunk, maxChunkSize)) {
                mergedChunk = potentialMergedChunk;
            } else {
                mergedOffsets.add(mergedChunk);
                mergedChunk = chunkOffsets.get(i);
            }

            // Flush the in-progress merged chunk after consuming the last element.
            if (i == chunkOffsets.size() - 1) {
                mergedOffsets.add(mergedChunk);
            }
        }
        return mergedOffsets;
    }

    /**
     * Fallback used when every configured separator has been tried and the piece is still too
     * large: chunk by sentence boundaries (with zero sentence overlap) and translate the
     * resulting offsets back into original-text coordinates via {@code chunkOffset}.
     */
    private List<ChunkOffset> chunkWithBackupChunker(String input, int maxChunkSize, int chunkOffset) {
        var chunks = new SentenceBoundaryChunker().chunk(input, new SentenceBoundaryChunkingSettings(maxChunkSize, 0));
        var chunksWithOffsets = new ArrayList<ChunkOffset>();
        for (var chunk : chunks) {
            chunksWithOffsets.add(new ChunkOffset(chunk.start() + chunkOffset, chunk.end() + chunkOffset));
        }
        return chunksWithOffsets;
    }
}

0 commit comments

Comments (0)