diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java index 5d04df5d2e1d5..c5e4abd3648c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java @@ -12,7 +12,7 @@ public enum ChunkingSettingsOptions { MAX_CHUNK_SIZE("max_chunk_size"), OVERLAP("overlap"), SENTENCE_OVERLAP("sentence_overlap"), - SEPARATOR_SET("separator_set"), + SEPARATOR_GROUP("separator_group"), SEPARATORS("separators"); private final String chunkingSettingsOption; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java index 690a3d8ff0efe..c68dc3b216744 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java @@ -60,7 +60,10 @@ private List chunk(String input, ChunkOffset offset, List s return chunkWithBackupChunker(input, offset, maxChunkSize); } - var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex)); + var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize( + splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex)), + maxChunkSize + ); var actualChunks = new ArrayList(); for (var potentialChunk : potentialChunks) { if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) { @@ -104,6 +107,33 @@ private List splitTextBySeparatorRegex(String input, ChunkO return chunkOffsets; } + private List mergeChunkOffsetsUpToMaxChunkSize(List chunkOffsets, int maxChunkSize) { + if (chunkOffsets.size() < 2) { + return chunkOffsets; + } + + List mergedOffsetsAndCounts = new ArrayList<>(); + var mergedChunk = chunkOffsets.getFirst(); + for (int i = 1; i < chunkOffsets.size(); i++) { + var chunkOffsetAndCountToMerge = chunkOffsets.get(i); + var potentialMergedChunk = new ChunkOffsetAndCount( + new ChunkOffset(mergedChunk.chunkOffset.start(), chunkOffsetAndCountToMerge.chunkOffset.end()), + mergedChunk.wordCount + chunkOffsetAndCountToMerge.wordCount + ); + if (isChunkWithinMaxSize(potentialMergedChunk, maxChunkSize)) { + mergedChunk = potentialMergedChunk; + } else { + mergedOffsetsAndCounts.add(mergedChunk); + mergedChunk = chunkOffsets.get(i); + } + + if (i == chunkOffsets.size() - 1) { + mergedOffsetsAndCounts.add(mergedChunk); + } + } + return mergedOffsetsAndCounts; + } + private List chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) { var chunks = new SentenceBoundaryChunker().chunk( input.substring(offset.start(), offset.end()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java index c368e1bb0c255..611736ceb4213 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java @@ -36,7 +36,7 @@ public class RecursiveChunkingSettings implements ChunkingSettings { private static final Set VALID_KEYS = Set.of( ChunkingSettingsOptions.STRATEGY.toString(), ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), - ChunkingSettingsOptions.SEPARATOR_SET.toString(), + ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), ChunkingSettingsOptions.SEPARATORS.toString() ); @@ -45,7 +45,7 @@ public class RecursiveChunkingSettings implements ChunkingSettings { public RecursiveChunkingSettings(int maxChunkSize, List separators) { this.maxChunkSize = maxChunkSize; - this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators; + this.separators = separators == null ? SeparatorGroup.PLAINTEXT.getSeparators() : separators; } public RecursiveChunkingSettings(StreamInput in) throws IOException { @@ -72,12 +72,12 @@ public static RecursiveChunkingSettings fromMap(Map map) { validationException ); - SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum( + SeparatorGroup separatorGroup = ServiceUtils.extractOptionalEnum( map, - ChunkingSettingsOptions.SEPARATOR_SET.toString(), + ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), ModelConfigurations.CHUNKING_SETTINGS, - SeparatorSet::fromString, - EnumSet.allOf(SeparatorSet.class), + SeparatorGroup::fromString, + EnumSet.allOf(SeparatorGroup.class), validationException ); @@ -88,12 +88,12 @@ public static RecursiveChunkingSettings fromMap(Map map) { validationException ); - if (separators != null && separatorSet != null) { + if (separators != null && separatorGroup != null) { validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set"); } - if (separatorSet != null) { - separators = separatorSet.getSeparators(); + if (separatorGroup != null) { + separators = separatorGroup.getSeparators(); } else if (separators != null && separators.isEmpty()) { validationException.addValidationError("Recursive chunking settings can not have an empty list of separators"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorGroup.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorGroup.java index 61b997b8d17a9..cafd3b08ccf9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorGroup.java @@ -10,17 +10,17 @@ import java.util.List; import java.util.Locale; -public enum SeparatorSet { +public enum SeparatorGroup { PLAINTEXT("plaintext"), MARKDOWN("markdown"); private final String name; - SeparatorSet(String name) { + SeparatorGroup(String name) { this.name = name; } - public static SeparatorSet fromString(String name) { + public static SeparatorGroup fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java index baa8429ae3c78..1cb90b11995fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java @@ -46,7 +46,7 @@ public void testChunkInputShorterThanMaxChunkSize() { assertExpectedChunksGenerated(input, settings, List.of(new Chunker.ChunkOffset(0, input.length()))); } - public void testChunkInputRequiresOneSplit() { + public void testChunkInputRequiresOneSplitWithoutMerges() { List separators = generateRandomSeparators(); RecursiveChunkingSettings settings = generateChunkingSettings(10, separators); String input = generateTestText(2, List.of(separators.getFirst())); @@ -58,7 +58,23 @@ public void testChunkInputRequiresOneSplit() { ); } - public void testChunkInputRequiresMultipleSplits() { + public void testChunkInputRequiresOneSplitWithMerges() { + List separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(20, separators); + String input = generateTestText(3, List.of(separators.getFirst(), separators.getFirst())); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.getFirst().length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length()) + ) + ); + } + + public void testChunkInputRequiresMultipleSplitsWithoutMerges() { var separators = generateRandomSeparators(); RecursiveChunkingSettings settings = generateChunkingSettings(15, separators); String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1))); @@ -78,6 +94,22 @@ public void testChunkInputRequiresMultipleSplits() { ); } + public void testChunkInputRequiresMultipleSplitsWithMerges() { + var separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(25, separators); + String input = generateTestText(4, List.of(separators.get(1), separators.getFirst(), separators.get(1))); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, input.length()) + ) + ); + } + public void testChunkInputDoesNotSplitWhenNoLongerExceedingMaxChunkSize() { var separators = randomSubsetOf(3, TEST_SEPARATORS); RecursiveChunkingSettings settings = generateChunkingSettings(25, separators); @@ -165,7 +197,7 @@ public void testChunkLongDocument() { public void testMarkdownChunking() { int numSentences = randomIntBetween(10, 50); - List separators = SeparatorSet.MARKDOWN.getSeparators(); + List separators = SeparatorGroup.MARKDOWN.getSeparators(); List validHeaders = List.of( "# Header\n", "## Header\n", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java index 40f14e88d2558..f833aa09b1aee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java @@ -32,15 +32,15 @@ public void testFromMapValidSettingsWithSeparators() { assertEquals(separators, settings.getSeparators()); } - public void testFromMapValidSettingsWithSeparatorSet() { + public void testFromMapValidSettingsWithSeparatorGroup() { var maxChunkSize = randomIntBetween(10, 300); - var separatorSet = randomFrom(SeparatorSet.values()); - Map validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorSet.name()), Optional.empty()); + var separatorGroup = randomFrom(SeparatorGroup.values()); + Map validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorGroup.name()), Optional.empty()); RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings); assertEquals(maxChunkSize, settings.getMaxChunkSize()); - assertEquals(separatorSet.getSeparators(), settings.getSeparators()); + assertEquals(separatorGroup.getSeparators(), settings.getSeparators()); } public void testFromMapMaxChunkSizeTooSmall() { @@ -55,7 +55,7 @@ public void testFromMapMaxChunkSizeTooLarge() { assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); } - public void testFromMapInvalidSeparatorSet() { + public void testFromMapInvalidSeparatorGroup() { Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.of("invalid"), Optional.empty()); assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); @@ -68,7 +68,7 @@ public void testFromMapInvalidSettingKey() { assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); } - public void testFromMapBothSeparatorsAndSeparatorSet() { + public void testFromMapBothSeparatorsAndSeparatorGroup() { Map invalidSettings = buildChunkingSettingsMap( randomIntBetween(10, 300), Optional.of("default"), @@ -86,13 +86,13 @@ public void testFromMapEmptySeparators() { private Map buildChunkingSettingsMap( int maxChunkSize, - Optional separatorSet, + Optional separatorGroup, Optional> separators ) { Map settingsMap = new HashMap<>(); settingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.RECURSIVE.toString()); settingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); - separatorSet.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_SET.toString(), s)); + separatorGroup.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_GROUP.toString(), s)); separators.ifPresent(strings -> settingsMap.put(ChunkingSettingsOptions.SEPARATORS.toString(), strings)); return settingsMap; }