diff --git a/docs/changelog/126866.yaml b/docs/changelog/126866.yaml new file mode 100644 index 0000000000000..ff2e9d2ce03cb --- /dev/null +++ b/docs/changelog/126866.yaml @@ -0,0 +1,5 @@ +pr: 126866 +summary: Add recursive chunker +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java b/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java index 78404c1b409ee..995a844118290 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java @@ -16,6 +16,7 @@ public enum ChunkingStrategy { WORD("word"), SENTENCE("sentence"), + RECURSIVE("recursive"), NONE("none"); private final String chunkingStrategy; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 39bc7e9b01086..d7c1c81dca761 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings; +import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings; import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; @@ -567,6 +568,9 @@ private static void addChunkingSettingsNamedWriteables(List namedWriteables) { diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java index f10748c2fec97..17d08c9e58634 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java @@ -19,6 +19,7 @@ public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) { case NONE -> NoopChunker.INSTANCE; case WORD -> new WordBoundaryChunker(); case SENTENCE -> new SentenceBoundaryChunker(); + case RECURSIVE -> new RecursiveChunker(); }; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerUtils.java new file mode 100644 index 0000000000000..4d391df48e3d6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerUtils.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import com.ibm.icu.text.BreakIterator; + +public class ChunkerUtils { + + // setText() should be applied before using this function. 
+ static int countWords(int start, int end, BreakIterator wordIterator) { + assert start < end; + wordIterator.preceding(start); // start of the current word + + int boundary = wordIterator.current(); + int wordCount = 0; + while (boundary != BreakIterator.DONE && boundary <= end) { + int wordStatus = wordIterator.getRuleStatus(); + if (wordStatus != BreakIterator.WORD_NONE) { + wordCount++; + } + boundary = wordIterator.next(); + } + + return wordCount; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index b1bc5987eaa99..2f912d891ef60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -48,6 +48,7 @@ public static ChunkingSettings fromMap(Map settings, boolean ret case NONE -> NoneChunkingSettings.INSTANCE; case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); + case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings)); }; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java index 93d435eb0b69f..5d04df5d2e1d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java @@ -11,7 +11,9 @@ public enum ChunkingSettingsOptions { STRATEGY("strategy"), MAX_CHUNK_SIZE("max_chunk_size"), OVERLAP("overlap"), - 
SENTENCE_OVERLAP("sentence_overlap"); + SENTENCE_OVERLAP("sentence_overlap"), + SEPARATOR_SET("separator_set"), + SEPARATORS("separators"); private final String chunkingSettingsOption; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java new file mode 100644 index 0000000000000..690a3d8ff0efe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunker.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import com.ibm.icu.text.BreakIterator; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.ChunkingSettings; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +/** + * Split text into chunks recursively based on a list of separator regex strings. + * The maximum chunk size is measured in words and controlled + * by {@code maxChunkSize}. For each separator the chunker will go through the following process: + * 1. Split the text on each regex match of the separator. + * 2. For each chunk produced by the split: + * 1. Return it if it is within the maximum chunk size. + * 2. Repeat the process using the next separator in the list if the chunk exceeds the maximum chunk size. + * If there are no more separators left to try, run the {@code SentenceBoundaryChunker} with the provided + * max chunk size and no overlaps.
+ */ +public class RecursiveChunker implements Chunker { + private final BreakIterator wordIterator; + + public RecursiveChunker() { + wordIterator = BreakIterator.getWordInstance(); + } + + @Override + public List chunk(String input, ChunkingSettings chunkingSettings) { + if (chunkingSettings instanceof RecursiveChunkingSettings recursiveChunkingSettings) { + return chunk( + input, + new ChunkOffset(0, input.length()), + recursiveChunkingSettings.getSeparators(), + recursiveChunkingSettings.getMaxChunkSize(), + 0 + ); + } else { + throw new IllegalArgumentException( + Strings.format("RecursiveChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy()) + ); + } + } + + private List chunk(String input, ChunkOffset offset, List separators, int maxChunkSize, int separatorIndex) { + if (offset.start() == offset.end() || isChunkWithinMaxSize(buildChunkOffsetAndCount(input, offset), maxChunkSize)) { + return List.of(offset); + } + + if (separatorIndex > separators.size() - 1) { + return chunkWithBackupChunker(input, offset, maxChunkSize); + } + + var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex)); + var actualChunks = new ArrayList(); + for (var potentialChunk : potentialChunks) { + if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) { + actualChunks.add(potentialChunk.chunkOffset()); + } else { + actualChunks.addAll(chunk(input, potentialChunk.chunkOffset(), separators, maxChunkSize, separatorIndex + 1)); + } + } + + return actualChunks; + } + + private boolean isChunkWithinMaxSize(ChunkOffsetAndCount chunkOffsetAndCount, int maxChunkSize) { + return chunkOffsetAndCount.wordCount <= maxChunkSize; + } + + private ChunkOffsetAndCount buildChunkOffsetAndCount(String fullText, ChunkOffset offset) { + wordIterator.setText(fullText); + return new ChunkOffsetAndCount(offset, ChunkerUtils.countWords(offset.start(), offset.end(), wordIterator)); + } + + private List 
splitTextBySeparatorRegex(String input, ChunkOffset offset, String separatorRegex) { + var pattern = Pattern.compile(separatorRegex, Pattern.MULTILINE); + var matcher = pattern.matcher(input).region(offset.start(), offset.end()); + + var chunkOffsets = new ArrayList(); + int chunkStart = offset.start(); + while (matcher.find()) { + var chunkEnd = matcher.start(); + + if (chunkStart < chunkEnd) { + chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, chunkEnd))); + } + chunkStart = chunkEnd; + } + + if (chunkStart < offset.end()) { + chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, offset.end()))); + } + + return chunkOffsets; + } + + private List chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) { + var chunks = new SentenceBoundaryChunker().chunk( + input.substring(offset.start(), offset.end()), + new SentenceBoundaryChunkingSettings(maxChunkSize, 0) + ); + var chunksWithOffsets = new ArrayList(); + for (var chunk : chunks) { + chunksWithOffsets.add(new ChunkOffset(chunk.start() + offset.start(), chunk.end() + offset.start())); + } + return chunksWithOffsets; + } + + private record ChunkOffsetAndCount(ChunkOffset chunkOffset, int wordCount) {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java new file mode 100644 index 0000000000000..c368e1bb0c255 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class RecursiveChunkingSettings implements ChunkingSettings { + public static final String NAME = "RecursiveChunkingSettings"; + private static final ChunkingStrategy STRATEGY = ChunkingStrategy.RECURSIVE; + private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10; + private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300; + + private static final Set VALID_KEYS = Set.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ChunkingSettingsOptions.SEPARATOR_SET.toString(), + ChunkingSettingsOptions.SEPARATORS.toString() + ); + + private final int maxChunkSize; + private final List separators; + + public RecursiveChunkingSettings(int maxChunkSize, List separators) { + this.maxChunkSize = maxChunkSize; + this.separators = separators == null ? 
SeparatorSet.PLAINTEXT.getSeparators() : separators; + } + + public RecursiveChunkingSettings(StreamInput in) throws IOException { + maxChunkSize = in.readInt(); + separators = in.readCollectionAsList(StreamInput::readString); + } + + public static RecursiveChunkingSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray(); + if (invalidSettings.length > 0) { + validationException.addValidationError( + Strings.format("Recursive chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) + ); + } + + Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween( + map, + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + MAX_CHUNK_SIZE_LOWER_LIMIT, + MAX_CHUNK_SIZE_UPPER_LIMIT, + ModelConfigurations.CHUNKING_SETTINGS, + validationException + ); + + SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum( + map, + ChunkingSettingsOptions.SEPARATOR_SET.toString(), + ModelConfigurations.CHUNKING_SETTINGS, + SeparatorSet::fromString, + EnumSet.allOf(SeparatorSet.class), + validationException + ); + + List separators = ServiceUtils.extractOptionalList( + map, + ChunkingSettingsOptions.SEPARATORS.toString(), + String.class, + validationException + ); + + if (separators != null && separatorSet != null) { + validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set"); + } + + if (separatorSet != null) { + separators = separatorSet.getSeparators(); + } else if (separators != null && separators.isEmpty()) { + validationException.addValidationError("Recursive chunking settings can not have an empty list of separators"); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new RecursiveChunkingSettings(maxChunkSize, separators); + } + + public int 
getMaxChunkSize() { + return maxChunkSize; + } + + public List getSeparators() { + return separators; + } + + @Override + public ChunkingStrategy getChunkingStrategy() { + return STRATEGY; + } + + @Override + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.SEPARATORS.toString(), + separators + ); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; // TODO: Add transport version + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(maxChunkSize); + out.writeCollection(separators, StreamOutput::writeString); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY); + builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + builder.field(ChunkingSettingsOptions.SEPARATORS.toString(), separators); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecursiveChunkingSettings that = (RecursiveChunkingSettings) o; + return Objects.equals(maxChunkSize, that.maxChunkSize) && Objects.equals(separators, that.separators); + } + + @Override + public int hashCode() { + return Objects.hash(maxChunkSize, separators); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index bf28e30074a9d..aa76c40085464 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -211,26 +211,7 @@ static int skipWords(int start, int numWords, BreakIterator wordIterator) { } private int countWords(int start, int end) { - return countWords(start, end, this.wordIterator); - } - - // Exposed for testing. wordIterator should have had - // setText() applied before using this function. - static int countWords(int start, int end, BreakIterator wordIterator) { - assert start < end; - wordIterator.preceding(start); // start of the current word - - int boundary = wordIterator.current(); - int wordCount = 0; - while (boundary != BreakIterator.DONE && boundary <= end) { - int wordStatus = wordIterator.getRuleStatus(); - if (wordStatus != BreakIterator.WORD_NONE) { - wordCount++; - } - boundary = wordIterator.next(); - } - - return wordCount; + return ChunkerUtils.countWords(start, end, this.wordIterator); } private static int overlapForChunkSize(int chunkSize) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java new file mode 100644 index 0000000000000..61b997b8d17a9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SeparatorSet.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import java.util.List; +import java.util.Locale; + +public enum SeparatorSet { + PLAINTEXT("plaintext"), + MARKDOWN("markdown"); + + private final String name; + + SeparatorSet(String name) { + this.name = name; + } + + public static SeparatorSet fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public List getSeparators() { + return switch (this) { + case PLAINTEXT -> List.of("(? List.of( + "\n# ", + "\n## ", + "\n### ", + "\n#### ", + "\n##### ", + "\n###### ", + "\n^(?!\\s*$).*\\n-{1,}\\n", + "\n^(?!\\s*$).*\\n={1,}\\n" + ); + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index dfdabcab13a9c..f8a77c4da2e85 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -424,6 +424,35 @@ public static String extractOptionalString( return optionalField; } + @SuppressWarnings("unchecked") + public static List extractOptionalList( + Map map, + String settingName, + Class type, + ValidationException validationException + ) { + int initialValidationErrorCount = validationException.validationErrors().size(); + var optionalField = ServiceUtils.removeAsType(map, settingName, List.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + if (optionalField != null) { + for (Object o : optionalField) { + if (o.getClass().equals(type) == false) { + validationException.addValidationError(ServiceUtils.invalidTypeErrorMsg(settingName, o, "String")); + } + } + } + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + 
return (List) optionalField; + } + public static Integer extractRequiredPositiveInteger( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerUtilsTests.java new file mode 100644 index 0000000000000..95b043136cd68 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerUtilsTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import com.ibm.icu.text.BreakIterator; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Locale; + +import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT; + +public class ChunkerUtilsTests extends ESTestCase { + public void testCountWords() { + // Test word count matches the whitespace separated word count. 
+ var splitByWhiteSpaceSentenceSizes = sentenceSizes(TEST_TEXT); + + var sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); + sentenceIterator.setText(TEST_TEXT); + + var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); + wordIterator.setText(TEST_TEXT); + + int start = 0; + int end = sentenceIterator.next(); + assertEquals(splitByWhiteSpaceSentenceSizes[0], ChunkerUtils.countWords(start, end, wordIterator)); + start = end; + end = sentenceIterator.next(); + assertEquals(splitByWhiteSpaceSentenceSizes[1], ChunkerUtils.countWords(start, end, wordIterator)); + start = end; + end = sentenceIterator.next(); + assertEquals(splitByWhiteSpaceSentenceSizes[2], ChunkerUtils.countWords(start, end, wordIterator)); + start = end; + end = sentenceIterator.next(); + assertEquals(splitByWhiteSpaceSentenceSizes[3], ChunkerUtils.countWords(start, end, wordIterator)); + + assertEquals(BreakIterator.DONE, sentenceIterator.next()); + } + + public void testCountWords_short() { + // Test word count matches the whitespace separated word count. + var text = "This is a short sentence. 
Followed by another."; + + var sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); + sentenceIterator.setText(text); + + var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); + wordIterator.setText(text); + + int start = 0; + int end = sentenceIterator.next(); + assertEquals(5, ChunkerUtils.countWords(0, end, wordIterator)); + start = end; + end = sentenceIterator.next(); + assertEquals(3, ChunkerUtils.countWords(start, end, wordIterator)); + assertEquals(BreakIterator.DONE, sentenceIterator.next()); + } + + public void testCountWords_WithSymbols() { + { + var text = "foo != bar"; + var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); + wordIterator.setText(text); + // "foo", "bar" - "!=" is not counted + assertEquals(2, ChunkerUtils.countWords(0, text.length(), wordIterator)); + } + { + var text = "foo & bar"; + var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); + wordIterator.setText(text); + // "foo", "bar" - the & is not counted + assertEquals(2, ChunkerUtils.countWords(0, text.length(), wordIterator)); + } + { + var text = "m&s"; + var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); + wordIterator.setText(text); + // "m", "s" - the & is not counted + assertEquals(2, ChunkerUtils.countWords(0, text.length(), wordIterator)); + } + } + + private int[] sentenceSizes(String text) { + var sentences = text.split("\\.\\s+"); + var lengths = new int[sentences.length]; + for (int i = 0; i < sentences.length; i++) { + // strip out the '=' signs as they are not counted as words by ICU + sentences[i] = sentences[i].replace("=", ""); + // split by hyphen or whitespace to match the way + // the ICU break iterator counts words + lengths[i] = sentences[i].split("[ \\-]+").length; + } + return lengths; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java index 9dfa417c3c477..4311b68ec18d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java @@ -30,6 +30,9 @@ public static ChunkingSettings createRandomChunkingSettings() { case SENTENCE -> { return new SentenceBoundaryChunkingSettings(randomIntBetween(20, 300), randomBoolean() ? 0 : 1); } + case RECURSIVE -> { + return new RecursiveChunkingSettings(randomIntBetween(10, 300), null); + } default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]"); } } @@ -48,7 +51,12 @@ public static Map createRandomChunkingSettingsMap() { chunkingSettingsMap.put(ChunkingSettingsOptions.OVERLAP.toString(), randomIntBetween(1, maxChunkSize / 2)); } - case SENTENCE -> chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomIntBetween(20, 300)); + case SENTENCE -> { + chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomIntBetween(20, 300)); + } + case RECURSIVE -> { + chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomIntBetween(10, 300)); + } default -> { } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java new file mode 100644 index 0000000000000..0047946b88575 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkerTests.java @@ -0,0 +1,232 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.List; + +public class RecursiveChunkerTests extends ESTestCase { + + private final List TEST_SEPARATORS = List.of("\n", "\f", "\t", "#"); + private final String TEST_SENTENCE = "This is a test sentence that has ten total words. "; + + public void testChunkWithInvalidChunkingSettings() { + RecursiveChunker chunker = new RecursiveChunker(); + ChunkingSettings invalidSettings = new SentenceBoundaryChunkingSettings(randomIntBetween(20, 300), randomIntBetween(0, 1)); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> { + chunker.chunk(randomAlphaOfLength(100), invalidSettings); + }); + + assertEquals("RecursiveChunker can't use ChunkingSettings with strategy [sentence]", exception.getMessage()); + } + + public void testChunkEmptyInput() { + RecursiveChunkingSettings settings = generateChunkingSettings(randomIntBetween(10, 300), generateRandomSeparators()); + assertExpectedChunksGenerated("", settings, List.of(new Chunker.ChunkOffset(0, 0))); + } + + public void testChunkSingleCharacterInput() { + RecursiveChunkingSettings settings = generateChunkingSettings(randomIntBetween(10, 300), generateRandomSeparators()); + assertExpectedChunksGenerated(randomAlphaOfLength(1), settings, List.of(new Chunker.ChunkOffset(0, 1))); + } + + public void testChunkInputShorterThanMaxChunkSize() { + var maxChunkSize = randomIntBetween(10, 300); + var input = randomAlphaOfLength(maxChunkSize - 1); + RecursiveChunkingSettings settings = generateChunkingSettings(maxChunkSize, generateRandomSeparators()); + assertExpectedChunksGenerated(input, settings, List.of(new Chunker.ChunkOffset(0, input.length()))); + } + + public void 
testChunkInputRequiresOneSplit() { + List separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(10, separators); + String input = generateTestText(2, List.of(separators.get(0))); + + assertExpectedChunksGenerated( + input, + settings, + List.of(new Chunker.ChunkOffset(0, TEST_SENTENCE.length()), new Chunker.ChunkOffset(TEST_SENTENCE.length(), input.length())) + ); + } + + public void testChunkInputRequiresMultipleSplits() { + var separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(15, separators); + String input = generateTestText(4, List.of(separators.get(1), separators.get(0), separators.get(1))); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length(); + var expectedSecondChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length(); + var expectedThirdChunkOffsetEnd = TEST_SENTENCE.length() * 3 + separators.get(0).length() + separators.get(1).length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, expectedSecondChunkOffsetEnd), + new Chunker.ChunkOffset(expectedSecondChunkOffsetEnd, expectedThirdChunkOffsetEnd), + new Chunker.ChunkOffset(expectedThirdChunkOffsetEnd, input.length()) + ) + ); + } + + public void testChunkInputDoesNotSplitWhenNoLongerExceedingMaxChunkSize() { + var separators = randomSubsetOf(3, TEST_SEPARATORS); + RecursiveChunkingSettings settings = generateChunkingSettings(25, separators); + // Generate a test text such that after each split a valid chunk is found that contains a subsequent separator. This tests that we + // do not continue to split once the chunk size is valid even if there are more separators present in the text. 
+ String input = generateTestText(5, List.of(separators.get(1), separators.get(0), separators.get(2), separators.get(1))); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length() * 2 + separators.get(1).length(); + var expectedSecondChunkOffsetEnd = TEST_SENTENCE.length() * 4 + separators.get(1).length() + separators.get(0).length() + separators + .get(2) + .length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, expectedSecondChunkOffsetEnd), + new Chunker.ChunkOffset(expectedSecondChunkOffsetEnd, input.length()) + ) + ); + } + + public void testChunkInputRequiresBackupChunkingStrategy() { + var separators = generateRandomSeparators(); + RecursiveChunkingSettings settings = generateChunkingSettings(10, separators); + String input = generateTestText(4, List.of("", separators.get(0), "")); + + var expectedFirstChunkOffsetEnd = TEST_SENTENCE.length(); + var expectedSecondChunkOffsetEnd = TEST_SENTENCE.length() * 2; + var expectedThirdChunkOffsetEnd = TEST_SENTENCE.length() * 3 + separators.get(0).length(); + assertExpectedChunksGenerated( + input, + settings, + List.of( + new Chunker.ChunkOffset(0, expectedFirstChunkOffsetEnd), + new Chunker.ChunkOffset(expectedFirstChunkOffsetEnd, expectedSecondChunkOffsetEnd), + new Chunker.ChunkOffset(expectedSecondChunkOffsetEnd, expectedThirdChunkOffsetEnd), + new Chunker.ChunkOffset(expectedThirdChunkOffsetEnd, input.length()) + ) + ); + } + + public void testChunkWithRegexSeparator() { + var separators = List.of("(? 
separators = generateRandomSeparators(); + List splittersAfterSentences = new ArrayList<>(); + for (int i = 0; i < numSentences - 1; i++) { + splittersAfterSentences.add(randomFrom(separators)); + } + RecursiveChunkingSettings settings = generateChunkingSettings(15, separators); + String input = generateTestText(numSentences, splittersAfterSentences); + + List expectedChunks = new ArrayList<>(); + int currentOffset = 0; + for (int i = 0; i < numSentences; i++) { + int chunkLength = TEST_SENTENCE.length(); + if (i > 0) { + chunkLength += splittersAfterSentences.get(i - 1).length(); + } + expectedChunks.add(new Chunker.ChunkOffset(currentOffset, currentOffset + chunkLength)); + currentOffset += chunkLength; + } + + assertExpectedChunksGenerated(input, settings, expectedChunks); + } + + public void testMarkdownChunking() { + int numSentences = randomIntBetween(10, 50); + List separators = SeparatorSet.MARKDOWN.getSeparators(); + List validHeaders = List.of( + "# Header\n", + "## Header\n", + "### Header\n", + "#### Header\n", + "##### Header\n", + "###### Header\n", + "Header\n-\n", + "Header\n=\n" + ); + List validSplittersAfterSentences = validHeaders.stream().map(header -> "\n" + header).toList(); + List splittersAfterSentences = new ArrayList<>(); + for (int i = 0; i < numSentences - 1; i++) { + splittersAfterSentences.add(randomFrom(validSplittersAfterSentences)); + } + RecursiveChunkingSettings settings = generateChunkingSettings(15, separators); + String input = generateTestText(numSentences, splittersAfterSentences); + String leadingHeader = randomFrom(validHeaders); + input = leadingHeader + input; + + List expectedChunks = new ArrayList<>(); + int currentOffset = 0; + for (int i = 0; i < numSentences; i++) { + int chunkLength = TEST_SENTENCE.length(); + if (i == 0) { + chunkLength += leadingHeader.length(); + } else { + chunkLength += splittersAfterSentences.get(i - 1).length(); + } + expectedChunks.add(new Chunker.ChunkOffset(currentOffset, currentOffset + 
chunkLength)); + currentOffset += chunkLength; + } + + assertExpectedChunksGenerated(input, settings, expectedChunks); + } + + private void assertExpectedChunksGenerated(String input, RecursiveChunkingSettings settings, List expectedChunks) { + RecursiveChunker chunker = new RecursiveChunker(); + List chunks = chunker.chunk(input, settings); + assertEquals(expectedChunks, chunks); + } + + private String generateTestText(int numSentences, List splittersAfterSentences) { + assert (splittersAfterSentences.size() == numSentences - 1); + + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numSentences; i++) { + sb.append(TEST_SENTENCE); + if (i < numSentences - 1) { + sb.append(splittersAfterSentences.get(i)); + } + } + return sb.toString(); + } + + private List generateRandomSeparators() { + return randomSubsetOf(randomIntBetween(2, 3), TEST_SEPARATORS); + } + + private RecursiveChunkingSettings generateChunkingSettings(int maxChunkSize, List separators) { + return new RecursiveChunkingSettings(maxChunkSize, separators); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java new file mode 100644 index 0000000000000..40f14e88d2558 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettingsTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class RecursiveChunkingSettingsTests extends AbstractWireSerializingTestCase { + + public void testFromMapValidSettingsWithSeparators() { + var maxChunkSize = randomIntBetween(10, 300); + var separators = randomList(1, 10, () -> randomAlphaOfLength(1)); + Map validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.empty(), Optional.of(separators)); + + RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings); + + assertEquals(maxChunkSize, settings.getMaxChunkSize()); + assertEquals(separators, settings.getSeparators()); + } + + public void testFromMapValidSettingsWithSeparatorSet() { + var maxChunkSize = randomIntBetween(10, 300); + var separatorSet = randomFrom(SeparatorSet.values()); + Map validSettings = buildChunkingSettingsMap(maxChunkSize, Optional.of(separatorSet.name()), Optional.empty()); + + RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings); + + assertEquals(maxChunkSize, settings.getMaxChunkSize()); + assertEquals(separatorSet.getSeparators(), settings.getSeparators()); + } + + public void testFromMapMaxChunkSizeTooSmall() { + Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(0, 9), Optional.empty(), Optional.empty()); + + assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); + } + + public void testFromMapMaxChunkSizeTooLarge() { + Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(301, 500), Optional.empty(), Optional.empty()); + + 
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); + } + + public void testFromMapInvalidSeparatorSet() { + Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.of("invalid"), Optional.empty()); + + assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); + } + + public void testFromMapInvalidSettingKey() { + Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.empty(), Optional.empty()); + invalidSettings.put("invalid_key", "invalid_value"); + + assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); + } + + public void testFromMapBothSeparatorsAndSeparatorSet() { + Map invalidSettings = buildChunkingSettingsMap( + randomIntBetween(10, 300), + Optional.of("default"), + Optional.of(List.of("\n\n", "\n")) + ); + + assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); + } + + public void testFromMapEmptySeparators() { + Map invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.empty(), Optional.of(List.of())); + + assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings)); + } + + private Map buildChunkingSettingsMap( + int maxChunkSize, + Optional separatorSet, + Optional> separators + ) { + Map settingsMap = new HashMap<>(); + settingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.RECURSIVE.toString()); + settingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + separatorSet.ifPresent(s -> settingsMap.put(ChunkingSettingsOptions.SEPARATOR_SET.toString(), s)); + separators.ifPresent(strings -> settingsMap.put(ChunkingSettingsOptions.SEPARATORS.toString(), strings)); + return settingsMap; + } + + @Override + protected Writeable.Reader instanceReader() { + return RecursiveChunkingSettings::new; + } + + @Override + protected 
RecursiveChunkingSettings createTestInstance() { + int maxChunkSize = randomIntBetween(10, 300); + int numSeparators = randomIntBetween(1, 10); + List separators = new ArrayList<>(); + for (int i = 0; i < numSeparators; i++) { + separators.add(randomAlphaOfLength(1)); + } + + return new RecursiveChunkingSettings(maxChunkSize, separators); + } + + @Override + protected RecursiveChunkingSettings mutateInstance(RecursiveChunkingSettings instance) throws IOException { + int maxChunkSize = randomValueOtherThan(instance.getMaxChunkSize(), () -> randomIntBetween(10, 300)); + List separators = instance.getSeparators(); + separators.add(randomAlphaOfLength(1)); + return new RecursiveChunkingSettings(maxChunkSize, separators); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index f81894ccd4bbb..841c9188bc6b1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -257,32 +257,6 @@ public void testShortLongShortSentences_WithOverlap() { assertTrue(chunks.get(4).trim().endsWith(".")); // full sentence(s) } - public void testCountWords() { - // Test word count matches the whitespace separated word count. 
- var splitByWhiteSpaceSentenceSizes = sentenceSizes(TEST_TEXT); - - var sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); - sentenceIterator.setText(TEST_TEXT); - - var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - wordIterator.setText(TEST_TEXT); - - int start = 0; - int end = sentenceIterator.next(); - assertEquals(splitByWhiteSpaceSentenceSizes[0], SentenceBoundaryChunker.countWords(start, end, wordIterator)); - start = end; - end = sentenceIterator.next(); - assertEquals(splitByWhiteSpaceSentenceSizes[1], SentenceBoundaryChunker.countWords(start, end, wordIterator)); - start = end; - end = sentenceIterator.next(); - assertEquals(splitByWhiteSpaceSentenceSizes[2], SentenceBoundaryChunker.countWords(start, end, wordIterator)); - start = end; - end = sentenceIterator.next(); - assertEquals(splitByWhiteSpaceSentenceSizes[3], SentenceBoundaryChunker.countWords(start, end, wordIterator)); - - assertEquals(BreakIterator.DONE, sentenceIterator.next()); - } - public void testSkipWords() { int numWords = 50; StringBuilder sb = new StringBuilder(); @@ -307,49 +281,6 @@ public void testSkipWords() { assertThat(pos, greaterThan(0)); } - public void testCountWords_short() { - // Test word count matches the whitespace separated word count. - var text = "This is a short sentence. 
Followed by another."; - - var sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); - sentenceIterator.setText(text); - - var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - wordIterator.setText(text); - - int start = 0; - int end = sentenceIterator.next(); - assertEquals(5, SentenceBoundaryChunker.countWords(0, end, wordIterator)); - start = end; - end = sentenceIterator.next(); - assertEquals(3, SentenceBoundaryChunker.countWords(start, end, wordIterator)); - assertEquals(BreakIterator.DONE, sentenceIterator.next()); - } - - public void testCountWords_WithSymbols() { - { - var text = "foo != bar"; - var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - wordIterator.setText(text); - // "foo", "bar" - "!=" is not counted - assertEquals(2, SentenceBoundaryChunker.countWords(0, text.length(), wordIterator)); - } - { - var text = "foo & bar"; - var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - wordIterator.setText(text); - // "foo", "bar" - the & is not counted - assertEquals(2, SentenceBoundaryChunker.countWords(0, text.length(), wordIterator)); - } - { - var text = "m&s"; - var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - wordIterator.setText(text); - // "m", "s" - the & is not counted - assertEquals(2, SentenceBoundaryChunker.countWords(0, text.length(), wordIterator)); - } - } - public void testChunkSplitLargeChunkSizesWithChunkingSettings() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 8ffdca90ef2d5..d90cb638709db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -27,6 +27,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalList; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; @@ -479,6 +480,61 @@ public void testExtractOptionalString_AddsException_WhenFieldIsEmpty() { assertThat(validation.validationErrors().get(0), is("[scope] Invalid value empty string. [key] must be a non-empty string")); } + public void testExtractOptionalList_CreatesList() { + var validation = new ValidationException(); + var list = List.of(randomAlphaOfLength(10), randomAlphaOfLength(10)); + + Map map = modifiableMap(Map.of("key", list)); + assertEquals(list, extractOptionalList(map, "key", String.class, validation)); + assertTrue(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + } + + public void testExtractOptionalList_AddsException_WhenFieldDoesNotExist() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", List.of(randomAlphaOfLength(10), randomAlphaOfLength(10)))); + assertNull(extractOptionalList(map, "abc", String.class, validation)); + assertThat(validation.validationErrors(), hasSize(1)); + assertThat(map.size(), is(1)); + } + + public void testExtractOptionalList_AddsException_WhenFieldIsEmpty() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); 
+ Map map = modifiableMap(Map.of("key", "")); + assertNull(extractOptionalList(map, "key", String.class, validation)); + assertFalse(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + } + + public void testExtractOptionalList_AddsException_WhenFieldIsNotAList() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", 1)); + assertNull(extractOptionalList(map, "key", String.class, validation)); + assertFalse(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + } + + public void testExtractOptionalList_AddsException_WhenFieldIsNotAListOfTheCorrectType() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", List.of(1, 2))); + assertNull(extractOptionalList(map, "key", String.class, validation)); + assertFalse(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + } + + public void testExtractOptionalList_AddsException_WhenFieldContainsMixedTypeValues() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", List.of(1, "a"))); + assertNull(extractOptionalList(map, "key", String.class, validation)); + assertFalse(validation.validationErrors().isEmpty()); + assertTrue(map.isEmpty()); + } + public void testExtractOptionalPositiveInt() { var validation = new ValidationException(); validation.addValidationError("previous error");