Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public enum ChunkingSettingsOptions {
MAX_CHUNK_SIZE("max_chunk_size"),
OVERLAP("overlap"),
SENTENCE_OVERLAP("sentence_overlap"),
SEPARATOR_SET("separator_set"),
SEPARATOR_GROUP("separator_group"),
SEPARATORS("separators");

private final String chunkingSettingsOption;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ private List<ChunkOffset> chunk(String input, ChunkOffset offset, List<String> s
return chunkWithBackupChunker(input, offset, maxChunkSize);
}

var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex));
var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize(
splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex)),
maxChunkSize
);
var actualChunks = new ArrayList<ChunkOffset>();
for (var potentialChunk : potentialChunks) {
if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) {
Expand Down Expand Up @@ -104,6 +107,33 @@ private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, ChunkO
return chunkOffsets;
}

private List<ChunkOffsetAndCount> mergeChunkOffsetsUpToMaxChunkSize(List<ChunkOffsetAndCount> chunkOffsets, int maxChunkSize) {
if (chunkOffsets.size() < 2) {
return chunkOffsets;
}

List<ChunkOffsetAndCount> mergedOffsetsAndCounts = new ArrayList<>();
var mergedChunk = chunkOffsets.getFirst();
for (int i = 1; i < chunkOffsets.size(); i++) {
var chunkOffsetAndCountToMerge = chunkOffsets.get(i);
var potentialMergedChunk = new ChunkOffsetAndCount(
new ChunkOffset(mergedChunk.chunkOffset.start(), chunkOffsetAndCountToMerge.chunkOffset.end()),
mergedChunk.wordCount + chunkOffsetAndCountToMerge.wordCount
);
if (isChunkWithinMaxSize(potentialMergedChunk, maxChunkSize)) {
mergedChunk = potentialMergedChunk;
} else {
mergedOffsetsAndCounts.add(mergedChunk);
mergedChunk = chunkOffsets.get(i);
}

if (i == chunkOffsets.size() - 1) {
mergedOffsetsAndCounts.add(mergedChunk);
}
}
return mergedOffsetsAndCounts;
}

private List<ChunkOffset> chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) {
var chunks = new SentenceBoundaryChunker().chunk(
input.substring(offset.start(), offset.end()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class RecursiveChunkingSettings implements ChunkingSettings {
private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
ChunkingSettingsOptions.SEPARATOR_GROUP.toString(),
ChunkingSettingsOptions.SEPARATORS.toString()
);

Expand All @@ -45,7 +45,7 @@ public class RecursiveChunkingSettings implements ChunkingSettings {

public RecursiveChunkingSettings(int maxChunkSize, List<String> separators) {
this.maxChunkSize = maxChunkSize;
this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators;
this.separators = separators == null ? SeparatorGroup.PLAINTEXT.getSeparators() : separators;
}

public RecursiveChunkingSettings(StreamInput in) throws IOException {
Expand All @@ -72,12 +72,12 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
validationException
);

SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum(
SeparatorGroup separatorGroup = ServiceUtils.extractOptionalEnum(
map,
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
ChunkingSettingsOptions.SEPARATOR_GROUP.toString(),
ModelConfigurations.CHUNKING_SETTINGS,
SeparatorSet::fromString,
EnumSet.allOf(SeparatorSet.class),
SeparatorGroup::fromString,
EnumSet.allOf(SeparatorGroup.class),
validationException
);

Expand All @@ -88,12 +88,12 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
validationException
);

if (separators != null && separatorSet != null) {
if (separators != null && separatorGroup != null) {
validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set");
}

if (separatorSet != null) {
separators = separatorSet.getSeparators();
if (separatorGroup != null) {
separators = separatorGroup.getSeparators();
} else if (separators != null && separators.isEmpty()) {
validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
import java.util.List;
import java.util.Locale;

public enum SeparatorSet {
public enum SeparatorGroup {
PLAINTEXT("plaintext"),
MARKDOWN("markdown");

private final String name;

SeparatorSet(String name) {
SeparatorGroup(String name) {
this.name = name;
}

public static SeparatorSet fromString(String name) {
public static SeparatorGroup fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void testChunkInputShorterThanMaxChunkSize() {
assertExpectedChunksGenerated(input, settings, List.of(new Chunker.ChunkOffset(0, input.length())));
}

public void testChunkInputRequiresOneSplit() {
public void testChunkInputRequiresOneSplitWithoutMerges() {
List<String> separators = generateRandomSeparators();
RecursiveChunkingSettings settings = generateChunkingSettings(10, separators);
String input = generateTestText(2, List.of(separators.getFirst()));
Expand All @@ -58,7 +58,23 @@ public void testChunkInputRequiresOneSplit() {
);
}

public void testChunkInputRequiresMultipleSplits() {
public void testChunkInputRequiresOneSplitWithMerges() {
List<String> separators = generateRandomSeparators();
RecursiveChunkingSettings settings = generateChunkingSettings(20, separators);
String input = generateTestText(3, List.of(separators.getFirst(), separators.getFirst()));

var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.getFirst().length();
assertExpectedChunksGenerated(
input,
settings,
List.of(
new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd),
new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length())
)
);
}

public void testChunkInputRequiresMultipleSplitsWithoutMerges() {
var separators = generateRandomSeparators();
RecursiveChunkingSettings settings = generateChunkingSettings(15, separators);
String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1)));
Expand All @@ -78,6 +94,22 @@ public void testChunkInputRequiresMultipleSplits() {
);
}

public void testChunkInputRequiresMultipleSplitsWithMerges() {
var separators = generateRandomSeparators();
RecursiveChunkingSettings settings = generateChunkingSettings(25, separators);
String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1)));

var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length();
assertExpectedChunksGenerated(
input,
settings,
List.of(
new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd),
new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length())
)
);
}

public void testChunkInputDoesNotSplitWhenNoLongerExceedingMaxChunkSize() {
var separators = randomSubsetOf(3, TEST_SEPARATORS);
RecursiveChunkingSettings settings = generateChunkingSettings(25, separators);
Expand Down Expand Up @@ -165,7 +197,7 @@ public void testChunkLongDocument() {

public void testMarkdownChunking() {
int numSentences = randomIntBetween(10, 50);
List<String> separators = SeparatorSet.MARKDOWN.getSeparators();
List<String> separators = SeparatorGroup.MARKDOWN.getSeparators();
List<String> validHeaders = List.of(
"# Header\n",
"## Header\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ public void testFromMapValidSettingsWithSeparators() {
assertEquals(separators, settings.getSeparators());
}

public void testFromMapValidSettingsWithSeparatorSet() {
public void testFromMapValidSettingsWithSeparatorGroup() {
var maxChunkSize = randomIntBetween(10, 300);
var separatorSet = randomFrom(SeparatorSet.values());
Map<String, Object> validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorSet.name()), Optional.empty());
var separatorGroup = randomFrom(SeparatorGroup.values());
Map<String, Object> validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorGroup.name()), Optional.empty());

RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings);

assertEquals(maxChunkSize, settings.getMaxChunkSize());
assertEquals(separatorSet.getSeparators(), settings.getSeparators());
assertEquals(separatorGroup.getSeparators(), settings.getSeparators());
}

public void testFromMapMaxChunkSizeTooSmall() {
Expand All @@ -55,7 +55,7 @@ public void testFromMapMaxChunkSizeTooLarge() {
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
}

public void testFromMapInvalidSeparatorSet() {
public void testFromMapInvalidSeparatorGroup() {
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.of("invalid"), Optional.empty());

assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
Expand All @@ -68,7 +68,7 @@ public void testFromMapInvalidSettingKey() {
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
}

public void testFromMapBothSeparatorsAndSeparatorSet() {
public void testFromMapBothSeparatorsAndSeparatorGroup() {
Map<String, Object> invalidSettings = buildChunkingSettingsMap(
randomIntBetween(10, 300),
Optional.of("default"),
Expand All @@ -86,13 +86,13 @@ public void testFromMapEmptySeparators() {

private Map<String, Object> buildChunkingSettingsMap(
int maxChunkSize,
Optional<String> separatorSet,
Optional<String> separatorGroup,
Optional<List<String>> separators
) {
Map<String, Object> settingsMap = new HashMap<>();
settingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.RECURSIVE.toString());
settingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
separatorSet.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_SET.toString(), s));
separatorGroup.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), s));
separators.ifPresent(strings -> settingsMap.put(ChunkingSettingsOptions.SEPARATORS.toString(), strings));
return settingsMap;
}
Expand Down