docs/changelog/126866.yaml (new file, 5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 126866
summary: Add recursive chunker
area: Machine Learning
type: enhancement
issues: []
ChunkingStrategy.java
@@ -16,6 +16,7 @@
public enum ChunkingStrategy {
WORD("word"),
SENTENCE("sentence"),
RECURSIVE("recursive"),
NONE("none");

private final String chunkingStrategy;
@@ -27,6 +27,7 @@
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.chunking.NoneChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
@@ -567,6 +568,9 @@ private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
SentenceBoundaryChunkingSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ChunkingSettings.class, RecursiveChunkingSettings.NAME, RecursiveChunkingSettings::new)
);
}

private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
@@ -19,6 +19,7 @@ public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) {
case NONE -> NoopChunker.INSTANCE;
case WORD -> new WordBoundaryChunker();
case SENTENCE -> new SentenceBoundaryChunker();
case RECURSIVE -> new RecursiveChunker();
};
}
}
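
Not part of the diff: a one-line usage sketch of the factory above. The enclosing class name is not visible in this hunk, so ChunkerBuilder is an assumption here.

// Assumed factory class name (not shown in this hunk): ChunkerBuilder.
Chunker chunker = ChunkerBuilder.fromChunkingStrategy(ChunkingStrategy.RECURSIVE);
// Per the new switch arm, this returns a RecursiveChunker.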
ChunkerUtils.java (new file)
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

public class ChunkerUtils {

// setText() should be applied before using this function.
static int countWords(int start, int end, BreakIterator wordIterator) {
assert start < end;
wordIterator.preceding(start); // start of the current word

int boundary = wordIterator.current();
int wordCount = 0;
while (boundary != BreakIterator.DONE && boundary <= end) {
int wordStatus = wordIterator.getRuleStatus();
if (wordStatus != BreakIterator.WORD_NONE) {
wordCount++;
}
boundary = wordIterator.next();
}

return wordCount;
}
}
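
Not part of the diff: a short usage sketch of countWords, with illustrative text and offsets. Note that setText() must be applied to the iterator first, as the comment above requires.

// Hypothetical caller of ChunkerUtils.countWords; the sample text and offsets are illustrative.
String text = "The quick brown fox jumps over the lazy dog.";
BreakIterator wordIterator = BreakIterator.getWordInstance();
wordIterator.setText(text); // required before calling countWords()
// Counts the word-end boundaries inside [4, text.length()] whose ICU rule status is not
// WORD_NONE, so whitespace and punctuation boundaries are skipped.
int words = ChunkerUtils.countWords(4, text.length(), wordIterator);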
@@ -48,6 +48,7 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
case NONE -> NoneChunkingSettings.INSTANCE;
case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings));
case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings));
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
};
}
}
ChunkingSettingsOptions.java
@@ -11,7 +11,9 @@ public enum ChunkingSettingsOptions {
STRATEGY("strategy"),
MAX_CHUNK_SIZE("max_chunk_size"),
OVERLAP("overlap"),
SENTENCE_OVERLAP("sentence_overlap");
SENTENCE_OVERLAP("sentence_overlap"),
SEPARATOR_SET("separator_set"),
SEPARATORS("separators");

private final String chunkingSettingsOption;

RecursiveChunker.java (new file)
@@ -0,0 +1,120 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.ChunkingSettings;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Splits text into chunks recursively based on a list of separator regex strings.
 * The maximum chunk size is measured in words and is controlled by {@code maxChunkSize}
 * in the {@link RecursiveChunkingSettings}. For each separator, the chunker goes through
 * the following process:
 * 1. Split the text on each regex match of the separator.
 * 2. For each resulting chunk:
 *    a. Return it if it is within the maximum chunk size.
 *    b. Otherwise, repeat the process with the next separator in the list.
 * If there are no separators left to try, fall back to the {@link SentenceBoundaryChunker}
 * with the provided max chunk size and no overlap.
 */
public class RecursiveChunker implements Chunker {
private final BreakIterator wordIterator;

public RecursiveChunker() {
wordIterator = BreakIterator.getWordInstance();
}

@Override
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
if (chunkingSettings instanceof RecursiveChunkingSettings recursiveChunkingSettings) {
return chunk(
input,
new ChunkOffset(0, input.length()),
recursiveChunkingSettings.getSeparators(),
recursiveChunkingSettings.getMaxChunkSize(),
0
);
} else {
throw new IllegalArgumentException(
Strings.format("RecursiveChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy())
);
}
}

private List<ChunkOffset> chunk(String input, ChunkOffset offset, List<String> separators, int maxChunkSize, int separatorIndex) {
if (offset.start() == offset.end() || isChunkWithinMaxSize(buildChunkOffsetAndCount(input, offset), maxChunkSize)) {
return List.of(offset);
}

if (separatorIndex >= separators.size()) {
return chunkWithBackupChunker(input, offset, maxChunkSize);
}

var potentialChunks = splitTextBySeparatorRegex(input, offset, separators.get(separatorIndex));
var actualChunks = new ArrayList<ChunkOffset>();
for (var potentialChunk : potentialChunks) {
if (isChunkWithinMaxSize(potentialChunk, maxChunkSize)) {
actualChunks.add(potentialChunk.chunkOffset());
} else {
actualChunks.addAll(chunk(input, potentialChunk.chunkOffset(), separators, maxChunkSize, separatorIndex + 1));
}
}

return actualChunks;
}

private boolean isChunkWithinMaxSize(ChunkOffsetAndCount chunkOffsetAndCount, int maxChunkSize) {
return chunkOffsetAndCount.wordCount <= maxChunkSize;
}

private ChunkOffsetAndCount buildChunkOffsetAndCount(String fullText, ChunkOffset offset) {
wordIterator.setText(fullText);
return new ChunkOffsetAndCount(offset, ChunkerUtils.countWords(offset.start(), offset.end(), wordIterator));
}

private List<ChunkOffsetAndCount> splitTextBySeparatorRegex(String input, ChunkOffset offset, String separatorRegex) {
var pattern = Pattern.compile(separatorRegex, Pattern.MULTILINE);
var matcher = pattern.matcher(input).region(offset.start(), offset.end());

var chunkOffsets = new ArrayList<ChunkOffsetAndCount>();
int chunkStart = offset.start();
while (matcher.find()) {
var chunkEnd = matcher.start();

if (chunkStart < chunkEnd) {
chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, chunkEnd)));
}
chunkStart = chunkEnd;
}

if (chunkStart < offset.end()) {
chunkOffsets.add(buildChunkOffsetAndCount(input, new ChunkOffset(chunkStart, offset.end())));
}

return chunkOffsets;
}

private List<ChunkOffset> chunkWithBackupChunker(String input, ChunkOffset offset, int maxChunkSize) {
var chunks = new SentenceBoundaryChunker().chunk(
input.substring(offset.start(), offset.end()),
new SentenceBoundaryChunkingSettings(maxChunkSize, 0)
);
var chunksWithOffsets = new ArrayList<ChunkOffset>();
for (var chunk : chunks) {
chunksWithOffsets.add(new ChunkOffset(chunk.start() + offset.start(), chunk.end() + offset.start()));
}
return chunksWithOffsets;
}

private record ChunkOffsetAndCount(ChunkOffset chunkOffset, int wordCount) {}
}
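
Not part of the diff: a minimal end-to-end sketch of the chunker. The input text, separator list, and chunk size are illustrative choices, not defaults shipped by this PR.

// Hypothetical driver: try paragraph breaks first, then single newlines.
var settings = new RecursiveChunkingSettings(50, List.of("\n\n", "\n"));
var chunker = new RecursiveChunker();
String input = "First paragraph...\n\nSecond paragraph...\n\nThird paragraph...";
for (var offset : chunker.chunk(input, settings)) {
    // Each ChunkOffset is a range into the original input string.
    System.out.println(input.substring(offset.start(), offset.end()));
}

Because chunkStart advances to matcher.start() rather than matcher.end(), a matched separator stays attached to the start of the following chunk, so the returned offsets cover the input without gaps.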
RecursiveChunkingSettings.java (new file)
@@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class RecursiveChunkingSettings implements ChunkingSettings {
public static final String NAME = "RecursiveChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.RECURSIVE;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;

private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
ChunkingSettingsOptions.SEPARATORS.toString()
);

private final int maxChunkSize;
private final List<String> separators;

public RecursiveChunkingSettings(int maxChunkSize, List<String> separators) {
this.maxChunkSize = maxChunkSize;
this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators;
}

public RecursiveChunkingSettings(StreamInput in) throws IOException {
maxChunkSize = in.readInt();
separators = in.readCollectionAsList(StreamInput::readString);
}

public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray();
if (invalidSettings.length > 0) {
validationException.addValidationError(
Strings.format("Recursive chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum(
map,
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
ModelConfigurations.CHUNKING_SETTINGS,
SeparatorSet::fromString,
EnumSet.allOf(SeparatorSet.class),
validationException
);

List<String> separators = ServiceUtils.extractOptionalList(
map,
ChunkingSettingsOptions.SEPARATORS.toString(),
String.class,
validationException
);

if (separators != null && separatorSet != null) {
validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set");
}

if (separatorSet != null) {
separators = separatorSet.getSeparators();
} else if (separators != null && separators.isEmpty()) {
validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new RecursiveChunkingSettings(maxChunkSize, separators);
}

public int getMaxChunkSize() {
return maxChunkSize;
}

public List<String> getSeparators() {
return separators;
}

@Override
public ChunkingStrategy getChunkingStrategy() {
return STRATEGY;
}

@Override
public Map<String, Object> asMap() {
return Map.of(
ChunkingSettingsOptions.STRATEGY.toString(),
STRATEGY.toString().toLowerCase(Locale.ROOT),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
maxChunkSize,
ChunkingSettingsOptions.SEPARATORS.toString(),
separators
);
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return null; // TODO: Add transport version
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(maxChunkSize);
out.writeCollection(separators, StreamOutput::writeString);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
{
builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY);
builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
builder.field(ChunkingSettingsOptions.SEPARATORS.toString(), separators);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RecursiveChunkingSettings that = (RecursiveChunkingSettings) o;
return maxChunkSize == that.maxChunkSize && Objects.equals(separators, that.separators);
}

@Override
public int hashCode() {
return Objects.hash(maxChunkSize, separators);
}
}
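
Not part of the diff: a sketch of the parsing path. The keys mirror ChunkingSettingsOptions above, and "plaintext" assumes the SeparatorSet constant referenced in the constructor accepts its lower-case name via fromString.

// Hypothetical caller; the ServiceUtils.extract* helpers consume map entries, so pass a mutable map.
Map<String, Object> raw = new HashMap<>(Map.of(
    "strategy", "recursive",      // valid but unused here; the enclosing builder dispatches on it
    "max_chunk_size", 100,        // required; must fall within [10, 300]
    "separator_set", "plaintext"  // mutually exclusive with an explicit "separators" list
));
RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(raw);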