Skip to content

Commit 0685124

Browse files
Clean up separator sets and add asMap function for RecrusiveChunkingSettings
1 parent 8418223 commit 0685124

File tree

6 files changed

+153
-67
lines changed

6 files changed

+153
-67
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import java.util.regex.Pattern;
1818

1919
public class RecursiveChunker implements Chunker {
20-
private BreakIterator wordIterator;
20+
private final BreakIterator wordIterator;
2121

2222
public RecursiveChunker() {
2323
wordIterator = BreakIterator.getWordInstance();
@@ -34,33 +34,31 @@ public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings)
3434
}
3535
}
3636

37-
private List<ChunkOffset> chunk(String input, List<String> splitters, int maxChunkSize, int splitterIndex, int chunkOffset) {
37+
private List<ChunkOffset> chunk(String input, List<String> separators, int maxChunkSize, int separatorIndex, int chunkOffset) {
3838
if (input.length() < 2 || isChunkWithinMaxSize(input, new ChunkOffset(0, input.length()), maxChunkSize)) {
3939
return List.of(new ChunkOffset(chunkOffset, chunkOffset + input.length()));
4040
}
4141

42-
if (splitterIndex > splitters.size() - 1) {
42+
if (separatorIndex > separators.size() - 1) {
4343
return chunkWithBackupChunker(input, maxChunkSize, chunkOffset);
4444
}
4545

46-
var potentialChunks = splitAndMergeChunks(input, splitters.get(splitterIndex), maxChunkSize);
46+
var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize(
47+
input,
48+
splitTextBySeparatorRegex(input, separators.get(separatorIndex)),
49+
maxChunkSize
50+
);
4751
var actualChunks = new ArrayList<ChunkOffset>();
4852
for (var potentialChunk : potentialChunks) {
49-
// TODO: Decide if we want to allow the first condition? Ex. "## This is a test...." split on "#" will create
50-
// a chunk with just "#" If the rest of the sentence is bigger than the maximum chunk size. We can either stop this by
51-
// doing something like splitting on the "current splitter" but skipping anything that matches the previous splitters
52-
// Similarly we could make the splitter a regex and update the default splitters to specifically match just the value without
53-
// Duplicate values around it
54-
// Or we can merge chunks across all levels after everything is done instead of merging them after each split
55-
if (potentialChunk.start() == potentialChunk.end() || isChunkWithinMaxSize(input, potentialChunk, maxChunkSize)) {
53+
if (isChunkWithinMaxSize(input, potentialChunk, maxChunkSize)) {
5654
actualChunks.add(new ChunkOffset(chunkOffset + potentialChunk.start(), chunkOffset + potentialChunk.end()));
5755
} else {
5856
actualChunks.addAll(
5957
chunk(
6058
input.substring(potentialChunk.start(), potentialChunk.end()),
61-
splitters,
59+
separators,
6260
maxChunkSize,
63-
splitterIndex + 1,
61+
separatorIndex + 1,
6462
chunkOffset + potentialChunk.start()
6563
)
6664
);
@@ -75,10 +73,6 @@ private boolean isChunkWithinMaxSize(String fullText, ChunkOffset chunk, int max
7573
return ChunkerUtils.countWords(chunk.start(), chunk.end(), wordIterator) <= maxChunkSize;
7674
}
7775

78-
private List<ChunkOffset> splitAndMergeChunks(String input, String separator, int maxChunkSize) {
79-
return mergeChunkOffsetsUpToMaxChunkSize(input, splitTextBySeparatorRegex(input, separator), maxChunkSize);
80-
}
81-
8276
private List<ChunkOffset> splitTextBySeparatorRegex(String input, String separatorRegex) {
8377
var pattern = Pattern.compile(separatorRegex);
8478
var matcher = pattern.matcher(input);
@@ -88,10 +82,9 @@ private List<ChunkOffset> splitTextBySeparatorRegex(String input, String separat
8882
int searchStart = 0;
8983
while (matcher.find(searchStart)) {
9084
var chunkEnd = matcher.start();
91-
if (chunkStart <= chunkEnd) {
85+
if (chunkStart < chunkEnd) {
9286
chunkOffsets.add(new ChunkOffset(chunkStart, chunkEnd));
9387
}
94-
// TODO: check what happens if it's an empty regex
9588
chunkStart = matcher.start();
9689
searchStart = matcher.end();
9790
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Arrays;
2323
import java.util.EnumSet;
2424
import java.util.List;
25+
import java.util.Locale;
2526
import java.util.Map;
2627
import java.util.Objects;
2728
import java.util.Set;
@@ -39,19 +40,12 @@ public class RecursiveChunkingSettings implements ChunkingSettings {
3940
ChunkingSettingsOptions.SEPARATORS.toString()
4041
);
4142

42-
protected static final Map<SeparatorSet, List<String>> SEPARATOR_SETS = Map.of(
43-
SeparatorSet.DEFAULT,
44-
List.of("(?<!\n)\n\n(?!\n)", "(?<!\n)\n(?!\n)"),
45-
SeparatorSet.MARKDOWN,
46-
List.of("(?<!#)###(?!#)", "(?<!#)##(?!#)", "(?<!#)#(?!#)") // TODO: What other ones do we want here?
47-
);
48-
4943
private final int maxChunkSize;
5044
private final List<String> separators;
5145

5246
public RecursiveChunkingSettings(int maxChunkSize, List<String> separators) {
5347
this.maxChunkSize = maxChunkSize;
54-
this.separators = separators == null ? SEPARATOR_SETS.get(SeparatorSet.DEFAULT) : separators;
48+
this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators;
5549
}
5650

5751
public RecursiveChunkingSettings(StreamInput in) throws IOException {
@@ -99,7 +93,9 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
9993
}
10094

10195
if (separatorSet != null) {
102-
separators = SEPARATOR_SETS.get(separatorSet);
96+
separators = separatorSet.getSeparators();
97+
} else if (separators != null && separators.isEmpty()) {
98+
validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
10399
}
104100

105101
if (validationException.validationErrors().isEmpty() == false) {
@@ -122,6 +118,18 @@ public ChunkingStrategy getChunkingStrategy() {
122118
return STRATEGY;
123119
}
124120

121+
@Override
122+
public Map<String, Object> asMap() {
123+
return Map.of(
124+
ChunkingSettingsOptions.STRATEGY.toString(),
125+
STRATEGY.toString().toLowerCase(Locale.ROOT),
126+
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
127+
maxChunkSize,
128+
ChunkingSettingsOptions.SEPARATORS.toString(),
129+
separators
130+
);
131+
}
132+
125133
@Override
126134
public String getWriteableName() {
127135
return NAME;
@@ -162,23 +170,4 @@ public boolean equals(Object o) {
162170
public int hashCode() {
163171
return Objects.hash(maxChunkSize, separators);
164172
}
165-
166-
protected enum SeparatorSet {
167-
DEFAULT("default"),
168-
MARKDOWN("markdown");
169-
170-
private final String name;
171-
172-
SeparatorSet(String name) {
173-
this.name = name;
174-
}
175-
176-
public static SeparatorSet fromString(String name) {
177-
return EnumSet.allOf(SeparatorSet.class)
178-
.stream()
179-
.filter(ss -> ss.name.equals(name))
180-
.findFirst()
181-
.orElseThrow(() -> new IllegalArgumentException(Strings.format("Invalid separator set %s", name)));
182-
}
183-
}
184173
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.chunking;
9+
10+
import java.util.List;
11+
import java.util.Locale;
12+
13+
public enum SeparatorSet {
14+
PLAINTEXT("plaintext"),
15+
MARKDOWN("markdown");
16+
17+
private final String name;
18+
19+
SeparatorSet(String name) {
20+
this.name = name;
21+
}
22+
23+
public static SeparatorSet fromString(String name) {
24+
return valueOf(name.trim().toUpperCase(Locale.ROOT));
25+
}
26+
27+
public List<String> getSeparators() {
28+
return switch (this) {
29+
case PLAINTEXT -> List.of("\n\n", "\n");
30+
case MARKDOWN -> List.of("\n#{1,6} ", "^(?!\\s*$).*\\n*{3,}\\n", "^(?!\\s*$).*\\n-{3,}\\n", "^(?!\\s*$).*\\n_{3,}\\n");
31+
};
32+
}
33+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
import org.elasticsearch.inference.ChunkingSettings;
1111
import org.elasticsearch.test.ESTestCase;
1212

13-
import java.util.ArrayList;
1413
import java.util.List;
1514

1615
public class RecursiveChunkerTests extends ESTestCase {
1716

18-
private final List<String> TEST_SEPARATORS = List.of("\n\n", "\n", "\f", "\t", "###", "##", "#");
17+
private final List<String> TEST_SEPARATORS = List.of("\n\n", "\n", "\f", "\t", "#");
1918
private final String TEST_SENTENCE = "This is a test sentence that has ten total words. ";
2019

2120
public void testChunkWithInvalidChunkingSettings() {
@@ -81,7 +80,7 @@ public void testChunkInputRequiresMultipleSplitsWithNoMerges() {
8180

8281
var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length();
8382
var expectedSecondChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length();
84-
var expectedThirdChunkOffsetEnd = TEST_SENTENCE.length() * 3 + +separators.getFirst().length() + separators.get(1).length();
83+
var expectedThirdChunkOffsetEnd = TEST_SENTENCE.length() * 3 + separators.getFirst().length() + separators.get(1).length();
8584
assertExpectedChunksGenerated(
8685
input,
8786
settings,
@@ -154,13 +153,6 @@ private List<String> generateRandomSeparators() {
154153
}
155154

156155
private RecursiveChunkingSettings generateChunkingSettings(int maxChunkSize, List<String> separators) {
157-
// Convert separators to regex with lookbehind and lookahead assertions to avoid splitting separators that are subsets of other
158-
// separators (ex. if \n is a separator, then \n\n should not be split into two \n).
159-
var separatorRegexList = new ArrayList<String>();
160-
for (var separator : separators) {
161-
separatorRegexList.add("(?<!" + separator + ")" + separator + "(?!" + separator + ")");
162-
}
163-
164-
return new RecursiveChunkingSettings(maxChunkSize, separatorRegexList);
156+
return new RecursiveChunkingSettings(maxChunkSize, separators);
165157
}
166158
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,65 @@
2121

2222
public class RecursiveChunkingSettingsTests extends AbstractWireSerializingTestCase<RecursiveChunkingSettings> {
2323

24-
public void testFromMapValidSettings() {
25-
Map<String, Object> validSettings = buildChunkingSettingsMap(50, Optional.empty(), Optional.of(List.of("\n\n", "\n")));
24+
public void testFromMapValidSettingsWithSeparators() {
25+
var maxChunkSize = randomIntBetween(10, 300);
26+
var separators = randomList(1, 10, () -> randomAlphaOfLength(1));
27+
Map<String, Object> validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.empty(), Optional.of(separators));
2628

2729
RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings);
2830

29-
assertEquals(50, settings.getMaxChunkSize());
30-
assertEquals(List.of("\n\n", "\n"), settings.getSeparators());
31+
assertEquals(maxChunkSize, settings.getMaxChunkSize());
32+
assertEquals(separators, settings.getSeparators());
33+
}
34+
35+
public void testFromMapValidSettingsWithSeparatorSet() {
36+
var maxChunkSize = randomIntBetween(10, 300);
37+
var separatorSet = randomFrom(SeparatorSet.values());
38+
Map<String, Object> validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorSet.name()), Optional.empty());
39+
40+
RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings);
41+
42+
assertEquals(maxChunkSize, settings.getMaxChunkSize());
43+
assertEquals(separatorSet.getSeparators(), settings.getSeparators());
3144
}
3245

3346
public void testFromMapMaxChunkSizeTooSmall() {
34-
Map<String, Object> invalidSettings = buildChunkingSettingsMap(5, Optional.empty(), Optional.empty());
47+
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(0, 9), Optional.empty(), Optional.empty());
3548

3649
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
3750
}
3851

3952
public void testFromMapMaxChunkSizeTooLarge() {
40-
Map<String, Object> invalidSettings = buildChunkingSettingsMap(500, Optional.empty(), Optional.empty());
53+
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(301, 500), Optional.empty(), Optional.empty());
4154

4255
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
4356
}
4457

4558
public void testFromMapInvalidSeparatorSet() {
46-
Map<String, Object> invalidSettings = buildChunkingSettingsMap(50, Optional.of("invalid"), Optional.empty());
59+
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.of("invalid"), Optional.empty());
4760

4861
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
4962
}
5063

5164
public void testFromMapInvalidSettingKey() {
52-
Map<String, Object> invalidSettings = buildChunkingSettingsMap(50, Optional.empty(), Optional.empty());
65+
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.empty(), Optional.empty());
5366
invalidSettings.put("invalid_key", "invalid_value");
5467

5568
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
5669
}
5770

5871
public void testFromMapBothSeparatorsAndSeparatorSet() {
59-
Map<String, Object> invalidSettings = buildChunkingSettingsMap(50, Optional.of("default"), Optional.of(List.of("\n\n", "\n")));
72+
Map<String, Object> invalidSettings = buildChunkingSettingsMap(
73+
randomIntBetween(10, 300),
74+
Optional.of("default"),
75+
Optional.of(List.of("\n\n", "\n"))
76+
);
77+
78+
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
79+
}
80+
81+
public void testFromMapEmptySeparators() {
82+
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.empty(), Optional.of(List.of()));
6083

6184
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
6285
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
2424
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
2525
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
26+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalList;
2627
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
2728
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
2829
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
@@ -469,6 +470,61 @@ public void testExtractOptionalString_AddsException_WhenFieldIsEmpty() {
469470
assertThat(validation.validationErrors().get(0), is("[scope] Invalid value empty string. [key] must be a non-empty string"));
470471
}
471472

473+
public void testExtractOptionalList_CreatesList() {
474+
var validation = new ValidationException();
475+
var list = List.of(randomAlphaOfLength(10), randomAlphaOfLength(10));
476+
477+
Map<String, Object> map = modifiableMap(Map.of("key", list));
478+
assertEquals(list, extractOptionalList(map, "key", String.class, validation));
479+
assertTrue(validation.validationErrors().isEmpty());
480+
assertTrue(map.isEmpty());
481+
}
482+
483+
public void testExtractOptionalList_AddsException_WhenFieldDoesNotExist() {
484+
var validation = new ValidationException();
485+
validation.addValidationError("previous error");
486+
Map<String, Object> map = modifiableMap(Map.of("key", List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))));
487+
assertNull(extractOptionalList(map, "abc", String.class, validation));
488+
assertThat(validation.validationErrors(), hasSize(1));
489+
assertThat(map.size(), is(1));
490+
}
491+
492+
public void testExtractOptionalList_AddsException_WhenFieldIsEmpty() {
493+
var validation = new ValidationException();
494+
validation.addValidationError("previous error");
495+
Map<String, Object> map = modifiableMap(Map.of("key", ""));
496+
assertNull(extractOptionalList(map, "key", String.class, validation));
497+
assertFalse(validation.validationErrors().isEmpty());
498+
assertTrue(map.isEmpty());
499+
}
500+
501+
public void testExtractOptionalList_AddsException_WhenFieldIsNotAList() {
502+
var validation = new ValidationException();
503+
validation.addValidationError("previous error");
504+
Map<String, Object> map = modifiableMap(Map.of("key", 1));
505+
assertNull(extractOptionalList(map, "key", String.class, validation));
506+
assertFalse(validation.validationErrors().isEmpty());
507+
assertTrue(map.isEmpty());
508+
}
509+
510+
public void testExtractOptionalList_AddsException_WhenFieldIsNotAListOfTheCorrectType() {
511+
var validation = new ValidationException();
512+
validation.addValidationError("previous error");
513+
Map<String, Object> map = modifiableMap(Map.of("key", List.of(1, 2)));
514+
assertNull(extractOptionalList(map, "key", String.class, validation));
515+
assertFalse(validation.validationErrors().isEmpty());
516+
assertTrue(map.isEmpty());
517+
}
518+
519+
public void testExtractOptionalList_AddsException_WhenFieldContainsMixedTypeValues() {
520+
var validation = new ValidationException();
521+
validation.addValidationError("previous error");
522+
Map<String, Object> map = modifiableMap(Map.of("key", List.of(1, "a")));
523+
assertNull(extractOptionalList(map, "key", String.class, validation));
524+
assertFalse(validation.validationErrors().isEmpty());
525+
assertTrue(map.isEmpty());
526+
}
527+
472528
public void testExtractOptionalPositiveInt() {
473529
var validation = new ValidationException();
474530
validation.addValidationError("previous error");

0 commit comments

Comments
 (0)