Skip to content

Commit 58b835b

Browse files
authored
[ML] Add sentence overlap option to the sentence chunking settings (elastic#114461) (elastic#114626)
1 parent e5a3e94 commit 58b835b

File tree

10 files changed

+291
-32
lines changed

10 files changed

+291
-32
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ static TransportVersion def(int id) {
240240
public static final TransportVersion SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS = def(8_764_00_0);
241241
public static final TransportVersion RETRIEVERS_TELEMETRY_ADDED = def(8_765_00_0);
242242
public static final TransportVersion ESQL_CACHED_STRING_SERIALIZATION = def(8_766_00_0);
243-
public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDE1D = def(8_767_00_0);
243+
public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0);
244244
public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
245245

246246
/*

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
public enum ChunkingSettingsOptions {
1111
STRATEGY("strategy"),
1212
MAX_CHUNK_SIZE("max_chunk_size"),
13-
OVERLAP("overlap");
13+
OVERLAP("overlap"),
14+
SENTENCE_OVERLAP("sentence_overlap");
1415

1516
private final String chunkingSettingsOption;
1617

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public class SentenceBoundaryChunker implements Chunker {
3434
public SentenceBoundaryChunker() {
3535
sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT);
3636
wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
37+
3738
}
3839

3940
/**
@@ -46,7 +47,7 @@ public SentenceBoundaryChunker() {
4647
@Override
4748
public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
4849
if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) {
49-
return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize);
50+
return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0);
5051
} else {
5152
throw new IllegalArgumentException(
5253
Strings.format(
@@ -64,7 +65,7 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
6465
* @param maxNumberWordsPerChunk Maximum size of the chunk
6566
* @return The input text chunked
6667
*/
67-
public List<String> chunk(String input, int maxNumberWordsPerChunk) {
68+
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
6869
var chunks = new ArrayList<String>();
6970

7071
sentenceIterator.setText(input);
@@ -75,24 +76,46 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk) {
7576
int sentenceStart = 0;
7677
int chunkWordCount = 0;
7778

79+
int wordsInPrecedingSentenceCount = 0;
80+
int previousSentenceStart = 0;
81+
7882
int boundary = sentenceIterator.next();
7983

8084
while (boundary != BreakIterator.DONE) {
8185
int sentenceEnd = sentenceIterator.current();
82-
int countWordsInSentence = countWords(sentenceStart, sentenceEnd);
86+
int wordsInSentenceCount = countWords(sentenceStart, sentenceEnd);
8387

84-
if (chunkWordCount + countWordsInSentence > maxNumberWordsPerChunk) {
88+
if (chunkWordCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
8589
// over the max chunk size, roll back to the last sentence
8690

91+
int nextChunkWordCount = wordsInSentenceCount;
8792
if (chunkWordCount > 0) {
8893
// add a new chunk containing all the input up to this sentence
8994
chunks.add(input.substring(chunkStart, chunkEnd));
90-
chunkStart = chunkEnd;
91-
chunkWordCount = countWordsInSentence; // the next chunk will contain this sentence
95+
96+
if (includePrecedingSentence) {
97+
if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) {
98+
// cut the last sentence
99+
int numWordsToSkip = numWordsToSkipInPreviousSentence(wordsInPrecedingSentenceCount, maxNumberWordsPerChunk);
100+
101+
chunkStart = skipWords(input, previousSentenceStart, numWordsToSkip);
102+
chunkWordCount = (wordsInPrecedingSentenceCount - numWordsToSkip) + wordsInSentenceCount;
103+
} else {
104+
chunkWordCount = wordsInPrecedingSentenceCount + wordsInSentenceCount;
105+
chunkStart = previousSentenceStart;
106+
}
107+
108+
nextChunkWordCount = chunkWordCount;
109+
} else {
110+
chunkStart = chunkEnd;
111+
chunkWordCount = wordsInSentenceCount; // the next chunk will contain this sentence
112+
}
92113
}
93114

94-
if (countWordsInSentence > maxNumberWordsPerChunk) {
95-
// This sentence is bigger than the max chunk size.
115+
// Is the next chunk larger than max chunk size?
116+
// If so split it
117+
if (nextChunkWordCount > maxNumberWordsPerChunk) {
118+
// This sentence (and optional overlap) is bigger than the max chunk size.
96119
// Split the sentence on the word boundary
97120
var sentenceSplits = splitLongSentence(
98121
input.substring(chunkStart, sentenceEnd),
@@ -113,7 +136,12 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk) {
113136
chunkWordCount = sentenceSplits.get(i).wordCount();
114137
}
115138
} else {
116-
chunkWordCount += countWordsInSentence;
139+
chunkWordCount += wordsInSentenceCount;
140+
}
141+
142+
if (includePrecedingSentence) {
143+
previousSentenceStart = sentenceStart;
144+
wordsInPrecedingSentenceCount = wordsInSentenceCount;
117145
}
118146

119147
sentenceStart = sentenceEnd;
@@ -133,6 +161,45 @@ static List<WordBoundaryChunker.ChunkPosition> splitLongSentence(String text, in
133161
return new WordBoundaryChunker().chunkPositions(text, maxNumberOfWords, overlap);
134162
}
135163

164+
static int numWordsToSkipInPreviousSentence(int wordsInPrecedingSentenceCount, int maxNumberWordsPerChunk) {
165+
var maxWordsInOverlap = maxWordsInOverlap(maxNumberWordsPerChunk);
166+
if (wordsInPrecedingSentenceCount > maxWordsInOverlap) {
167+
return wordsInPrecedingSentenceCount - maxWordsInOverlap;
168+
} else {
169+
return 0;
170+
}
171+
}
172+
173+
static int maxWordsInOverlap(int maxNumberWordsPerChunk) {
174+
return Math.min(maxNumberWordsPerChunk / 2, 20);
175+
}
176+
177+
private int skipWords(String input, int start, int numWords) {
178+
var itr = BreakIterator.getWordInstance(Locale.ROOT);
179+
itr.setText(input);
180+
return skipWords(start, numWords, itr);
181+
}
182+
183+
static int skipWords(int start, int numWords, BreakIterator wordIterator) {
184+
wordIterator.preceding(start); // start of the current word
185+
186+
int boundary = wordIterator.current();
187+
int wordCount = 0;
188+
while (boundary != BreakIterator.DONE && wordCount < numWords) {
189+
int wordStatus = wordIterator.getRuleStatus();
190+
if (wordStatus != BreakIterator.WORD_NONE) {
191+
wordCount++;
192+
}
193+
boundary = wordIterator.next();
194+
}
195+
196+
if (boundary == BreakIterator.DONE) {
197+
return wordIterator.last();
198+
} else {
199+
return boundary;
200+
}
201+
}
202+
136203
private int countWords(int start, int end) {
137204
return countWords(start, end, this.wordIterator);
138205
}
@@ -157,6 +224,6 @@ static int countWords(int start, int end, BreakIterator wordIterator) {
157224
}
158225

159226
private static int overlapForChunkSize(int chunkSize) {
160-
return (chunkSize - 1) / 2;
227+
return Math.min(20, (chunkSize - 1) / 2);
161228
}
162229
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.ValidationException;
1414
import org.elasticsearch.common.io.stream.StreamInput;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.core.Nullable;
1617
import org.elasticsearch.inference.ChunkingSettings;
1718
import org.elasticsearch.inference.ChunkingStrategy;
1819
import org.elasticsearch.inference.ModelConfigurations;
@@ -30,16 +31,25 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
3031
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE;
3132
private static final Set<String> VALID_KEYS = Set.of(
3233
ChunkingSettingsOptions.STRATEGY.toString(),
33-
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString()
34+
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
35+
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString()
3436
);
37+
38+
private static int DEFAULT_OVERLAP = 0;
39+
3540
protected final int maxChunkSize;
41+
protected int sentenceOverlap = DEFAULT_OVERLAP;
3642

37-
public SentenceBoundaryChunkingSettings(Integer maxChunkSize) {
43+
public SentenceBoundaryChunkingSettings(Integer maxChunkSize, @Nullable Integer sentenceOverlap) {
3844
this.maxChunkSize = maxChunkSize;
45+
this.sentenceOverlap = sentenceOverlap == null ? DEFAULT_OVERLAP : sentenceOverlap;
3946
}
4047

4148
public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException {
4249
maxChunkSize = in.readInt();
50+
if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) {
51+
sentenceOverlap = in.readVInt();
52+
}
4353
}
4454

4555
public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map) {
@@ -59,11 +69,24 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
5969
validationException
6070
);
6171

72+
Integer sentenceOverlap = ServiceUtils.extractOptionalPositiveInteger(
73+
map,
74+
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
75+
ModelConfigurations.CHUNKING_SETTINGS,
76+
validationException
77+
);
78+
79+
if (sentenceOverlap != null && sentenceOverlap > 1) {
80+
validationException.addValidationError(
81+
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() + "[" + sentenceOverlap + "] must be either 0 or 1"
82+
); // todo better
83+
}
84+
6285
if (validationException.validationErrors().isEmpty() == false) {
6386
throw validationException;
6487
}
6588

66-
return new SentenceBoundaryChunkingSettings(maxChunkSize);
89+
return new SentenceBoundaryChunkingSettings(maxChunkSize, sentenceOverlap);
6790
}
6891

6992
@Override
@@ -72,6 +95,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7295
{
7396
builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY);
7497
builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
98+
builder.field(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), sentenceOverlap);
7599
}
76100
builder.endObject();
77101
return builder;
@@ -90,6 +114,9 @@ public TransportVersion getMinimalSupportedVersion() {
90114
@Override
91115
public void writeTo(StreamOutput out) throws IOException {
92116
out.writeInt(maxChunkSize);
117+
if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) {
118+
out.writeVInt(sentenceOverlap);
119+
}
93120
}
94121

95122
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public static WordBoundaryChunkingSettings fromMap(Map<String, Object> map) {
5252
var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray();
5353
if (invalidSettings.length > 0) {
5454
validationException.addValidationError(
55-
Strings.format("Sentence based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
55+
Strings.format("Word based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
5656
);
5757
}
5858

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ private Map<Map<String, Object>, ChunkingSettings> chunkingSettingsMapToChunking
5656
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
5757
maxChunkSize
5858
),
59-
new SentenceBoundaryChunkingSettings(maxChunkSize)
59+
new SentenceBoundaryChunkingSettings(maxChunkSize, 1)
6060
);
6161
}
6262
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public static ChunkingSettings createRandomChunkingSettings() {
2525
return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2));
2626
}
2727
case SENTENCE -> {
28-
return new SentenceBoundaryChunkingSettings(randomNonNegativeInt());
28+
return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 0 : 1);
2929
}
3030
default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]");
3131
}

0 commit comments

Comments
 (0)