Merged pull request — changes from 1 commit.

RecursiveChunker.java
@@ -21,8 +21,7 @@
* The maximum chunk size is measured in words and controlled
* by {@code maxNumberWordsPerChunk}. For each separator the chunker will go through the following process:
* 1. Split the text on each regex match of the separator.
- * 2. Merge consecutive chunks when it is possible to do so without exceeding the max chunk size.
- * 3. For each chunk after the merge:
+ * 2. For each chunk after the merge:
* 1. Return it if it is within the maximum chunk size.
* 2. Repeat the process using the next separator in the list if the chunk exceeds the maximum chunk size.
* If there are no more separators left to try, run the {@code SentenceBoundaryChunker} with the provided
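For orientation, here is how the chunker is driven end to end. A minimal sketch, assuming a `RecursiveChunkingSettings(maxChunkSize, separators)` constructor shape inferred from `generateChunkingSettings` in the tests further down (the constructor itself is not shown in this diff):

```java
// Sketch only: the settings constructor shape is an assumption.
// Chunk plaintext on blank lines first, then single newlines,
// keeping each chunk within 10 words.
String input = "First paragraph of the document.\n\nSecond paragraph of the document.";
var settings = new RecursiveChunkingSettings(10, SeparatorSet.PLAINTEXT.getSeparators());
List<Chunker.ChunkOffset> offsets = new RecursiveChunker().chunk(input, settings);
// Each ChunkOffset is a [start, end) range into the original input string.
```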
@@ -38,40 +37,36 @@ public RecursiveChunker() {
@Override
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
if (chunkingSettings instanceof RecursiveChunkingSettings recursiveChunkingSettings) {
- return chunk(input, recursiveChunkingSettings.getSeparators(), recursiveChunkingSettings.getMaxChunkSize(), 0, 0);
+ return chunk(
+     input,
+     new ChunkOffset(0, input.length()),
+     recursiveChunkingSettings.getSeparators(),
+     recursiveChunkingSettings.getMaxChunkSize(),
+     0
+ );
} else {
throw new IllegalArgumentException(
Strings.format("RecursiveChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy())
);
}
}

- private List<ChunkOffset> chunk(String input, List<String> separators, int maxChunkSize, int separatorIndex, int chunkOffset) {
-     if (input.length() < 2 || isChunkWithinMaxSize(buildChunkOffsetAndCount(input, 0, input.length()), maxChunkSize)) {
-         return List.of(new ChunkOffset(chunkOffset, chunkOffset + input.length()));
+ private List<ChunkOffset> chunk(String input, ChunkOffset offset, List<String> separators, int maxChunkSize, int separatorIndex) {
+     if (offset.start() == offset.end() || isChunkWithinMaxSize(buildChunkOffsetAndCount(input, offset), maxChunkSize)) {
+         return List.of(offset);
}

if (separatorIndex > separators.size() - 1) {
- return chunkWithBackupChunker(input, maxChunkSize, chunkOffset);
+ return chunkWithBackupChunker(input, offset, maxChunkSize);
}

- var potentialChunks = splitTextBySeparatorRegex(input, separators.get(separatorIndex));
+ var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex));
var actualChunks = new ArrayList<ChunkOffset>();
for (var potentialChunk : potentialChunks) {
if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) {
- actualChunks.add(
-     new ChunkOffset(chunkOffset + potentialChunk.chunkOffset.start(), chunkOffset + potentialChunk.chunkOffset.end())
- );
+ actualChunks.add(potentialChunk.chunkOffset());
} else {
- actualChunks.addAll(
-     chunk(
-         input.substring(potentialChunk.chunkOffset.start(), potentialChunk.chunkOffset.end()),
-         separators,
-         maxChunkSize,
-         separatorIndex + 1,
-         chunkOffset + potentialChunk.chunkOffset.start()
-     )
- );
+ actualChunks.addAll(chunk(input, potentialChunk.chunkOffset(), separators, maxChunkSize, separatorIndex + 1));
}
}

@@ -82,39 +77,46 @@ private boolean isChunkWithinMaxSize(ChunkOffsetAndCount chunkOffsetAndCount, int maxChunkSize) {
return chunkOffsetAndCount.wordCount <= maxChunkSize;
}

- private ChunkOffsetAndCount buildChunkOffsetAndCount(String fullText, int chunkStart, int chunkEnd) {
-     var chunkOffset = new ChunkOffset(chunkStart, chunkEnd);
-
+ private ChunkOffsetAndCount buildChunkOffsetAndCount(String fullText, ChunkOffset offset) {
wordIterator.setText(fullText);
- return new ChunkOffsetAndCount(chunkOffset, ChunkerUtils.countWords(chunkStart, chunkEnd, wordIterator));
+ return new ChunkOffsetAndCount(offset, ChunkerUtils.countWords(offset.start(), offset.end(), wordIterator));
}

- private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, String separatorRegex) {
-     var pattern = Pattern.compile(separatorRegex);
+ private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, ChunkOffset offset, String separatorRegex) {
+     var pattern = Pattern.compile(separatorRegex, Pattern.MULTILINE);
var matcher = pattern.matcher(input);

var chunkOffsets = new ArrayList<ChunkOffsetAndCount>();
- int chunkStart = 0;
- while (matcher.find()) {
+ int chunkStart = offset.start();
+ int searchStart = offset.start();
+ while (matcher.find(searchStart)) {
var chunkEnd = matcher.start();
+ if (chunkEnd >= offset.end()) {
+     break; // No more matches within the chunk offset
+ }

if (chunkStart < chunkEnd) {
- chunkOffsets.add(buildChunkOffsetAndCount(input, chunkStart, chunkEnd));
+ chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, chunkEnd)));
}
- chunkStart = matcher.start();
+ chunkStart = chunkEnd;
+ searchStart = matcher.end();
}

- if (chunkStart < input.length()) {
-     chunkOffsets.add(buildChunkOffsetAndCount(input, chunkStart, input.length()));
+ if (chunkStart < offset.end()) {
+     chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, offset.end())));
}

return chunkOffsets;
}
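The switch from `matcher.find()` to `matcher.find(searchStart)` is what lets a single matcher over the full string be confined to the current chunk's window, instead of recursing on substring copies. A self-contained sketch of that idiom using plain `java.util.regex` (illustrative values, not the PR's code):

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class WindowedScanDemo {
    public static void main(String[] args) {
        // Scan "a\nb\nc\nd" but only report separators inside the window [2, 5).
        // find(int) resets the matcher and resumes the search at the given
        // index, so earlier text is skipped without any substring copies.
        Pattern pattern = Pattern.compile("\\n", Pattern.MULTILINE);
        Matcher matcher = pattern.matcher("a\nb\nc\nd");
        int windowEnd = 5;
        int searchStart = 2;
        while (matcher.find(searchStart)) {
            if (matcher.start() >= windowEnd) {
                break; // match lies past the window, as in the PR's loop
            }
            System.out.println("separator at index " + matcher.start()); // prints 3
            searchStart = matcher.end();
        }
    }
}
```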

- private List<ChunkOffset> chunkWithBackupChunker(String input, int maxChunkSize, int chunkOffset) {
-     var chunks = new SentenceBoundaryChunker().chunk(input, new SentenceBoundaryChunkingSettings(maxChunkSize, 0));
+ private List<ChunkOffset> chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) {
+     var chunks = new SentenceBoundaryChunker().chunk(
+         input.substring(offset.start(), offset.end()),
+         new SentenceBoundaryChunkingSettings(maxChunkSize, 0)
+     );
var chunksWithOffsets = new ArrayList<ChunkOffset>();
for (var chunk : chunks) {
- chunksWithOffsets.add(new ChunkOffset(chunk.start() + chunkOffset, chunk.end() + chunkOffset));
+ chunksWithOffsets.add(new ChunkOffset(chunk.start() + offset.start(), chunk.end() + offset.start()));
}
return chunksWithOffsets;
}
SeparatorSet.java
@@ -26,17 +26,16 @@ public static SeparatorSet fromString(String name) {

public List<String> getSeparators() {
return switch (this) {
case PLAINTEXT -> List.of("\n\n", "\n");
case PLAINTEXT -> List.of("(?<!\\n)\\n\\n(?!\\n)", "(?<!\\n)\\n(?!\\n)");
case MARKDOWN -> List.of(
"\n# ",
"\n## ",
"\n### ",
"\n#### ",
"\n##### ",
"\n###### ",
"^(?!\\s*$).*\\n*{3,}\\n",
"^(?!\\s*$).*\\n-{3,}\\n",
"^(?!\\s*$).*\\n_{3,}\\n"
"\n^(?!\\s*$).*\\n-{1,}\\n",
"\n^(?!\\s*$).*\\n={1,}\\n"
);
};
}
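The reworked PLAINTEXT separators use lookarounds so each pattern matches an exact newline run: the double-newline separator no longer fires inside a run of three or more newlines, and the single-newline separator no longer fires inside a paragraph break. A minimal check of that behavior:

```java
import java.util.regex.Pattern;

class PlaintextSeparatorDemo {
    public static void main(String[] args) {
        // Matches exactly two consecutive newlines, no more, no fewer.
        Pattern doubleNewline = Pattern.compile("(?<!\\n)\\n\\n(?!\\n)");
        System.out.println(doubleNewline.matcher("a\n\nb").find());   // true
        System.out.println(doubleNewline.matcher("a\n\n\nb").find()); // false

        // Matches a lone newline that is not part of a paragraph break.
        Pattern singleNewline = Pattern.compile("(?<!\\n)\\n(?!\\n)");
        System.out.println(singleNewline.matcher("a\nb").find());     // true
        System.out.println(singleNewline.matcher("a\n\nb").find());   // false
    }
}
```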
RecursiveChunkerTests.java
@@ -15,7 +15,7 @@

public class RecursiveChunkerTests extends ESTestCase {

private final List<String> TEST_SEPARATORS = List.of("\n\n", "\n", "\f", "\t", "#");
private final List<String> TEST_SEPARATORS = List.of("\n", "\f", "\t", "#");
private final String TEST_SENTENCE = "This is a test sentence that has ten total words. ";

public void testChunkWithInvalidChunkingSettings() {
@@ -142,6 +142,45 @@ public void testChunkLongDocument() {
assertExpectedChunksGenerated(input, settings, expectedChunks);
}

+ public void testMarkdownChunking() {
+     int numSentences = randomIntBetween(10, 50);
+     List<String> separators = SeparatorSet.MARKDOWN.getSeparators();
+     List<String> validHeaders = List.of(
+         "# Header\n",
+         "## Header\n",
+         "### Header\n",
+         "#### Header\n",
+         "##### Header\n",
+         "###### Header\n",
+         "Header\n-\n",
+         "Header\n=\n"
+     );
+     List<String> validSplittersAfterSentences = validHeaders.stream().map(header -> "\n" + header).toList();
+     List<String> splittersAfterSentences = new ArrayList<>();
+     for (int i = 0; i < numSentences - 1; i++) {
+         splittersAfterSentences.add(randomFrom(validSplittersAfterSentences));
+     }
+     RecursiveChunkingSettings settings = generateChunkingSettings(15, separators);
Member commented:
Because of the small chunk size, the generated chunks will never contain more than one sentence. Can you structure the test so that some chunks contain multiple heading sections?

For example, if the chunk size was 100 words, then given the document

# heading1.1
## heading1.2.1 
TEST_SENTENCE * 3

## heading1.2.2 
TEST_SENTENCE * 2

# heading2.1
## heading2.2.1 
TEST_SENTENCE * 9

## heading2.2.2 
TEST_SENTENCE 

### heading2.3.1
TEST_SENTENCE 

### heading2.3.2
TEST_SENTENCE  

In this case, given an ordered list of separators, I would expect # heading1.1 -> # heading2.1 to be a single chunk, then two more chunks for heading2.2.1 and heading2.2.2.

Please add tests on longer documents that capture the hierarchical nature of the chunker.

Member Author replied:
Discussed offline with Dave. Adding this to the existing long-document tests that randomly generate a document would require essentially rewriting the chunking logic in the test file to generate the expected chunk limits. We've instead decided it makes sense to add a new test with a smaller fixed-length document to cover this case.

+     String input = generateTestText(numSentences, splittersAfterSentences);
+     String leadingHeader = randomFrom(validHeaders);
+     input = leadingHeader + input;
+
+     List<Chunker.ChunkOffset> expectedChunks = new ArrayList<>();
+     int currentOffset = 0;
+     for (int i = 0; i < numSentences; i++) {
+         int chunkLength = TEST_SENTENCE.length();
+         if (i == 0) {
+             chunkLength += leadingHeader.length();
+         } else {
+             chunkLength += splittersAfterSentences.get(i - 1).length();
+         }
+         expectedChunks.add(new Chunker.ChunkOffset(currentOffset, currentOffset + chunkLength));
+         currentOffset += chunkLength;
+     }
+
+     assertExpectedChunksGenerated(input, settings, expectedChunks);
+ }
+
private void assertExpectedChunksGenerated(String input, RecursiveChunkingSettings settings, List<Chunker.ChunkOffset> expectedChunks) {
RecursiveChunker chunker = new RecursiveChunker();
List<Chunker.ChunkOffset> chunks = chunker.chunk(input, settings);