Skip to content

Commit e92de38

Browse files
dan-rubinsteinelasticsearchmachineelasticmachine
authored
[8.19] Add recursive chunker (#126866) (#129656)
* Add recursive chunker (#126866) * Add recursive chunker * Update docs/changelog/126866.yaml * Clean up separator sets and add asMap function for RecrusiveChunkingSettings * Add javadoc for chunker, add tests, reduce word counting operations * Remove split merging and add long document unit test * [CI] Auto commit changes from spotless * Add markdown chunking tests and reduce substring calls * Clean up matcher logic * Add testing for not splitting after valid chunk is found --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Elastic Machine <[email protected]> * Update getFirst to get in recursive chunker tests --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Elastic Machine <[email protected]>
1 parent a6dd703 commit e92de38

File tree

18 files changed

+932
-91
lines changed

18 files changed

+932
-91
lines changed

docs/changelog/126866.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126866
2+
summary: Add recursive chunker
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
public enum ChunkingStrategy {
1717
WORD("word"),
1818
SENTENCE("sentence"),
19+
RECURSIVE("recursive"),
1920
NONE("none");
2021

2122
private final String chunkingStrategy;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
2828
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
2929
import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings;
30+
import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings;
3031
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
3132
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
3233
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
@@ -567,6 +568,9 @@ private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegist
567568
SentenceBoundaryChunkingSettings::new
568569
)
569570
);
571+
namedWriteables.add(
572+
new NamedWriteableRegistry.Entry(ChunkingSettings.class, RecursiveChunkingSettings.NAME, RecursiveChunkingSettings::new)
573+
);
570574
}
571575

572576
private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) {
1919
case NONE -> NoopChunker.INSTANCE;
2020
case WORD -> new WordBoundaryChunker();
2121
case SENTENCE -> new SentenceBoundaryChunker();
22+
case RECURSIVE -> new RecursiveChunker();
2223
};
2324
}
2425
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.chunking;
9+
10+
import com.ibm.icu.text.BreakIterator;
11+
12+
public class ChunkerUtils {
13+
14+
// setText() should be applied before using this function.
15+
static int countWords(int start, int end, BreakIterator wordIterator) {
16+
assert start < end;
17+
wordIterator.preceding(start); // start of the current word
18+
19+
int boundary = wordIterator.current();
20+
int wordCount = 0;
21+
while (boundary != BreakIterator.DONE && boundary <= end) {
22+
int wordStatus = wordIterator.getRuleStatus();
23+
if (wordStatus != BreakIterator.WORD_NONE) {
24+
wordCount++;
25+
}
26+
boundary = wordIterator.next();
27+
}
28+
29+
return wordCount;
30+
}
31+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
4848
case NONE -> NoneChunkingSettings.INSTANCE;
4949
case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings));
5050
case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings));
51+
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
5152
};
5253
}
5354
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ public enum ChunkingSettingsOptions {
1111
STRATEGY("strategy"),
1212
MAX_CHUNK_SIZE("max_chunk_size"),
1313
OVERLAP("overlap"),
14-
SENTENCE_OVERLAP("sentence_overlap");
14+
SENTENCE_OVERLAP("sentence_overlap"),
15+
SEPARATOR_SET("separator_set"),
16+
SEPARATORS("separators");
1517

1618
private final String chunkingSettingsOption;
1719

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.chunking;
9+
10+
import com.ibm.icu.text.BreakIterator;
11+
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.inference.ChunkingSettings;
14+
15+
import java.util.ArrayList;
16+
import java.util.List;
17+
import java.util.regex.Pattern;
18+
19+
/**
20+
* Split text into chunks recursively based on a list of separator regex strings.
21+
* The maximum chunk size is measured in words and controlled
22+
* by {@code maxNumberWordsPerChunk}. For each separator the chunker will go through the following process:
23+
* 1. Split the text on each regex match of the separator.
24+
* 2. For each chunk after the merge:
25+
* 1. Return it if it is within the maximum chunk size.
26+
* 2. Repeat the process using the next separator in the list if the chunk exceeds the maximum chunk size.
27+
* If there are no more separators left to try, run the {@code SentenceBoundaryChunker} with the provided
28+
* max chunk size and no overlaps.
29+
*/
30+
public class RecursiveChunker implements Chunker {
31+
private final BreakIterator wordIterator;
32+
33+
public RecursiveChunker() {
34+
wordIterator = BreakIterator.getWordInstance();
35+
}
36+
37+
@Override
38+
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
39+
if (chunkingSettings instanceof RecursiveChunkingSettings recursiveChunkingSettings) {
40+
return chunk(
41+
input,
42+
new ChunkOffset(0, input.length()),
43+
recursiveChunkingSettings.getSeparators(),
44+
recursiveChunkingSettings.getMaxChunkSize(),
45+
0
46+
);
47+
} else {
48+
throw new IllegalArgumentException(
49+
Strings.format("RecursiveChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy())
50+
);
51+
}
52+
}
53+
54+
private List<ChunkOffset> chunk(String input, ChunkOffset offset, List<String> separators, int maxChunkSize, int separatorIndex) {
55+
if (offset.start() == offset.end() || isChunkWithinMaxSize(buildChunkOffsetAndCount(input, offset), maxChunkSize)) {
56+
return List.of(offset);
57+
}
58+
59+
if (separatorIndex > separators.size() - 1) {
60+
return chunkWithBackupChunker(input, offset, maxChunkSize);
61+
}
62+
63+
var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex));
64+
var actualChunks = new ArrayList<ChunkOffset>();
65+
for (var potentialChunk : potentialChunks) {
66+
if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) {
67+
actualChunks.add(potentialChunk.chunkOffset());
68+
} else {
69+
actualChunks.addAll(chunk(input, potentialChunk.chunkOffset(), separators, maxChunkSize, separatorIndex + 1));
70+
}
71+
}
72+
73+
return actualChunks;
74+
}
75+
76+
private boolean isChunkWithinMaxSize(ChunkOffsetAndCount chunkOffsetAndCount, int maxChunkSize) {
77+
return chunkOffsetAndCount.wordCount <= maxChunkSize;
78+
}
79+
80+
private ChunkOffsetAndCount buildChunkOffsetAndCount(String fullText, ChunkOffset offset) {
81+
wordIterator.setText(fullText);
82+
return new ChunkOffsetAndCount(offset, ChunkerUtils.countWords(offset.start(), offset.end(), wordIterator));
83+
}
84+
85+
private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, ChunkOffset offset, String separatorRegex) {
86+
var pattern = Pattern.compile(separatorRegex, Pattern.MULTILINE);
87+
var matcher = pattern.matcher(input).region(offset.start(), offset.end());
88+
89+
var chunkOffsets = new ArrayList<ChunkOffsetAndCount>();
90+
int chunkStart = offset.start();
91+
while (matcher.find()) {
92+
var chunkEnd = matcher.start();
93+
94+
if (chunkStart < chunkEnd) {
95+
chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, chunkEnd)));
96+
}
97+
chunkStart = chunkEnd;
98+
}
99+
100+
if (chunkStart < offset.end()) {
101+
chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, offset.end())));
102+
}
103+
104+
return chunkOffsets;
105+
}
106+
107+
private List<ChunkOffset> chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) {
108+
var chunks = new SentenceBoundaryChunker().chunk(
109+
input.substring(offset.start(), offset.end()),
110+
new SentenceBoundaryChunkingSettings(maxChunkSize, 0)
111+
);
112+
var chunksWithOffsets = new ArrayList<ChunkOffset>();
113+
for (var chunk : chunks) {
114+
chunksWithOffsets.add(new ChunkOffset(chunk.start() + offset.start(), chunk.end() + offset.start()));
115+
}
116+
return chunksWithOffsets;
117+
}
118+
119+
private record ChunkOffsetAndCount(ChunkOffset chunkOffset, int wordCount) {}
120+
}
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.chunking;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.inference.ChunkingSettings;
16+
import org.elasticsearch.inference.ChunkingStrategy;
17+
import org.elasticsearch.inference.ModelConfigurations;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
20+
21+
import java.io.IOException;
22+
import java.util.Arrays;
23+
import java.util.EnumSet;
24+
import java.util.List;
25+
import java.util.Locale;
26+
import java.util.Map;
27+
import java.util.Objects;
28+
import java.util.Set;
29+
30+
public class RecursiveChunkingSettings implements ChunkingSettings {
31+
public static final String NAME = "RecursiveChunkingSettings";
32+
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.RECURSIVE;
33+
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10;
34+
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;
35+
36+
private static final Set<String> VALID_KEYS = Set.of(
37+
ChunkingSettingsOptions.STRATEGY.toString(),
38+
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
39+
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
40+
ChunkingSettingsOptions.SEPARATORS.toString()
41+
);
42+
43+
private final int maxChunkSize;
44+
private final List<String> separators;
45+
46+
public RecursiveChunkingSettings(int maxChunkSize, List<String> separators) {
47+
this.maxChunkSize = maxChunkSize;
48+
this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators;
49+
}
50+
51+
public RecursiveChunkingSettings(StreamInput in) throws IOException {
52+
maxChunkSize = in.readInt();
53+
separators = in.readCollectionAsList(StreamInput::readString);
54+
}
55+
56+
public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
57+
ValidationException validationException = new ValidationException();
58+
59+
var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray();
60+
if (invalidSettings.length > 0) {
61+
validationException.addValidationError(
62+
Strings.format("Recursive chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
63+
);
64+
}
65+
66+
Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
67+
map,
68+
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
69+
MAX_CHUNK_SIZE_LOWER_LIMIT,
70+
MAX_CHUNK_SIZE_UPPER_LIMIT,
71+
ModelConfigurations.CHUNKING_SETTINGS,
72+
validationException
73+
);
74+
75+
SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum(
76+
map,
77+
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
78+
ModelConfigurations.CHUNKING_SETTINGS,
79+
SeparatorSet::fromString,
80+
EnumSet.allOf(SeparatorSet.class),
81+
validationException
82+
);
83+
84+
List<String> separators = ServiceUtils.extractOptionalList(
85+
map,
86+
ChunkingSettingsOptions.SEPARATORS.toString(),
87+
String.class,
88+
validationException
89+
);
90+
91+
if (separators != null && separatorSet != null) {
92+
validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set");
93+
}
94+
95+
if (separatorSet != null) {
96+
separators = separatorSet.getSeparators();
97+
} else if (separators != null && separators.isEmpty()) {
98+
validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
99+
}
100+
101+
if (validationException.validationErrors().isEmpty() == false) {
102+
throw validationException;
103+
}
104+
105+
return new RecursiveChunkingSettings(maxChunkSize, separators);
106+
}
107+
108+
public int getMaxChunkSize() {
109+
return maxChunkSize;
110+
}
111+
112+
public List<String> getSeparators() {
113+
return separators;
114+
}
115+
116+
@Override
117+
public ChunkingStrategy getChunkingStrategy() {
118+
return STRATEGY;
119+
}
120+
121+
@Override
122+
public Map<String, Object> asMap() {
123+
return Map.of(
124+
ChunkingSettingsOptions.STRATEGY.toString(),
125+
STRATEGY.toString().toLowerCase(Locale.ROOT),
126+
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
127+
maxChunkSize,
128+
ChunkingSettingsOptions.SEPARATORS.toString(),
129+
separators
130+
);
131+
}
132+
133+
@Override
134+
public String getWriteableName() {
135+
return NAME;
136+
}
137+
138+
@Override
139+
public TransportVersion getMinimalSupportedVersion() {
140+
return null; // TODO: Add transport version
141+
}
142+
143+
@Override
144+
public void writeTo(StreamOutput out) throws IOException {
145+
out.writeInt(maxChunkSize);
146+
out.writeCollection(separators, StreamOutput::writeString);
147+
}
148+
149+
@Override
150+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
151+
builder.startObject();
152+
{
153+
builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY);
154+
builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
155+
builder.field(ChunkingSettingsOptions.SEPARATORS.toString(), separators);
156+
}
157+
builder.endObject();
158+
return builder;
159+
}
160+
161+
@Override
162+
public boolean equals(Object o) {
163+
if (this == o) return true;
164+
if (o == null || getClass() != o.getClass()) return false;
165+
RecursiveChunkingSettings that = (RecursiveChunkingSettings) o;
166+
return Objects.equals(maxChunkSize, that.maxChunkSize) && Objects.equals(separators, that.separators);
167+
}
168+
169+
@Override
170+
public int hashCode() {
171+
return Objects.hash(maxChunkSize, separators);
172+
}
173+
}

0 commit comments

Comments
 (0)