Skip to content

Commit c8a5f0c

Browse files
Remove split merging and add long document unit test
1 parent 6f649fc commit c8a5f0c

File tree

3 files changed

+39
-66
lines changed

3 files changed

+39
-66
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@ private List<ChunkOffset> chunk(String input, List<String> separators, int maxCh
5555
return chunkWithBackupChunker(input, maxChunkSize, chunkOffset);
5656
}
5757

58-
var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize(
59-
splitTextBySeparatorRegex(input, separators.get(separatorIndex)),
60-
maxChunkSize
61-
);
58+
var potentialChunks = splitTextBySeparatorRegex(input, separators.get(separatorIndex));
6259
var actualChunks = new ArrayList<ChunkOffset>();
6360
for (var potentialChunk : potentialChunks) {
6461
if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) {
@@ -113,33 +110,6 @@ private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, String
113110
return chunkOffsets;
114111
}
115112

116-
private List<ChunkOffsetAndCount> mergeChunkOffsetsUpToMaxChunkSize(List<ChunkOffsetAndCount> chunkOffsets, int maxChunkSize) {
117-
if (chunkOffsets.size() < 2) {
118-
return chunkOffsets;
119-
}
120-
121-
List<ChunkOffsetAndCount> mergedOffsetsAndCounts = new ArrayList<>();
122-
var mergedChunk = chunkOffsets.getFirst();
123-
for (int i = 1; i < chunkOffsets.size(); i++) {
124-
var chunkOffsetAndCountToMerge = chunkOffsets.get(i);
125-
var potentialMergedChunk = new ChunkOffsetAndCount(
126-
new ChunkOffset(mergedChunk.chunkOffset.start(), chunkOffsetAndCountToMerge.chunkOffset.end()),
127-
mergedChunk.wordCount + chunkOffsetAndCountToMerge.wordCount
128-
);
129-
if (isChunkWithinMaxSize(potentialMergedChunk, maxChunkSize)) {
130-
mergedChunk = potentialMergedChunk;
131-
} else {
132-
mergedOffsetsAndCounts.add(mergedChunk);
133-
mergedChunk = chunkOffsets.get(i);
134-
}
135-
136-
if (i == chunkOffsets.size() - 1) {
137-
mergedOffsetsAndCounts.add(mergedChunk);
138-
}
139-
}
140-
return mergedOffsetsAndCounts;
141-
}
142-
143113
private List<ChunkOffset> chunkWithBackupChunker(String input, int maxChunkSize, int chunkOffset) {
144114
var chunks = new SentenceBoundaryChunker().chunk(input, new SentenceBoundaryChunkingSettings(maxChunkSize, 0));
145115
var chunksWithOffsets = new ArrayList<ChunkOffset>();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,17 @@ public static SeparatorSet fromString(String name) {
2727
public List<String> getSeparators() {
2828
return switch (this) {
2929
case PLAINTEXT -> List.of("\n\n", "\n");
30-
case MARKDOWN -> List.of("\n#{1,6} ", "^(?!\\s*$).*\\n*{3,}\\n", "^(?!\\s*$).*\\n-{3,}\\n", "^(?!\\s*$).*\\n_{3,}\\n");
30+
case MARKDOWN -> List.of(
31+
"\n# ",
32+
"\n## ",
33+
"\n### ",
34+
"\n#### ",
35+
"\n##### ",
36+
"\n###### ",
37+
"^(?!\\s*$).*\\n*{3,}\\n",
38+
"^(?!\\s*$).*\\n-{3,}\\n",
39+
"^(?!\\s*$).*\\n_{3,}\\n"
40+
);
3141
};
3242
}
3343
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.inference.ChunkingSettings;
1111
import org.elasticsearch.test.ESTestCase;
1212

13+
import java.util.ArrayList;
1314
import java.util.List;
1415

1516
public class RecursiveChunkerTests extends ESTestCase {
@@ -45,7 +46,7 @@ public void testChunkInputShorterThanMaxChunkSize() {
4546
assertExpectedChunksGenerated(input, settings, List.of(new Chunker.ChunkOffset(0, input.length())));
4647
}
4748

48-
public void testChunkInputRequiresOneSplitWithNoMerges() {
49+
public void testChunkInputRequiresOneSplit() {
4950
List<String> separators = generateRandomSeparators();
5051
RecursiveChunkingSettings settings = generateChunkingSettings(10, separators);
5152
String input = generateTestText(2, List.of(separators.getFirst()));
@@ -57,23 +58,7 @@ public void testChunkInputRequiresOneSplitWithNoMerges() {
5758
);
5859
}
5960

60-
public void testChunkInputRequiresOneSplitWithMerges() {
61-
List<String> separators = generateRandomSeparators();
62-
RecursiveChunkingSettings settings = generateChunkingSettings(20, separators);
63-
String input = generateTestText(3, List.of(separators.getFirst(), separators.getFirst()));
64-
65-
var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.getFirst().length();
66-
assertExpectedChunksGenerated(
67-
input,
68-
settings,
69-
List.of(
70-
new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd),
71-
new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length())
72-
)
73-
);
74-
}
75-
76-
public void testChunkInputRequiresMultipleSplitsWithNoMerges() {
61+
public void testChunkInputRequiresMultipleSplits() {
7762
var separators = generateRandomSeparators();
7863
RecursiveChunkingSettings settings = generateChunkingSettings(15, separators);
7964
String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1)));
@@ -93,22 +78,6 @@ public void testChunkInputRequiresMultipleSplitsWithNoMerges() {
9378
);
9479
}
9580

96-
public void testChunkInputRequiresMultipleSplitsWithMerges() {
97-
var separators = generateRandomSeparators();
98-
RecursiveChunkingSettings settings = generateChunkingSettings(25, separators);
99-
String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1)));
100-
101-
var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length();
102-
assertExpectedChunksGenerated(
103-
input,
104-
settings,
105-
List.of(
106-
new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd),
107-
new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length())
108-
)
109-
);
110-
}
111-
11281
public void testChunkInputRequiresBackupChunkingStrategy() {
11382
var separators = generateRandomSeparators();
11483
RecursiveChunkingSettings settings = generateChunkingSettings(10, separators);
@@ -149,6 +118,30 @@ public void testChunkWithRegexSeparator() {
149118
);
150119
}
151120

121+
public void testChunkLongDocument() {
122+
int numSentences = randomIntBetween(50, 100);
123+
List<String> separators = generateRandomSeparators();
124+
List<String> splittersAfterSentences = new ArrayList<>();
125+
for (int i = 0; i < numSentences - 1; i++) {
126+
splittersAfterSentences.add(randomFrom(separators));
127+
}
128+
RecursiveChunkingSettings settings = generateChunkingSettings(15, separators);
129+
String input = generateTestText(numSentences, splittersAfterSentences);
130+
131+
List<Chunker.ChunkOffset> expectedChunks = new ArrayList<>();
132+
int currentOffset = 0;
133+
for (int i = 0; i < numSentences; i++) {
134+
int chunkLength = TEST_SENTENCE.length();
135+
if (i > 0) {
136+
chunkLength += splittersAfterSentences.get(i - 1).length();
137+
}
138+
expectedChunks.add(new Chunker.ChunkOffset(currentOffset, currentOffset + chunkLength));
139+
currentOffset += chunkLength;
140+
}
141+
142+
assertExpectedChunksGenerated(input, settings, expectedChunks);
143+
}
144+
152145
private void assertExpectedChunksGenerated(String input, RecursiveChunkingSettings settings, List<Chunker.ChunkOffset> expectedChunks) {
153146
RecursiveChunker chunker = new RecursiveChunker();
154147
List<Chunker.ChunkOffset> chunks = chunker.chunk(input, settings);

0 commit comments

Comments (0)