Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126866.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126866
summary: Add recursive chunker
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

public enum ChunkingStrategy {
WORD("word"),
SENTENCE("sentence");
SENTENCE("sentence"),
RECURSIVE("recursive");

private final String chunkingStrategy;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.chunking.RecursiveChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
Expand Down Expand Up @@ -472,6 +473,9 @@ private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegist
SentenceBoundaryChunkingSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ChunkingSettings.class, RecursiveChunkingSettings.NAME, RecursiveChunkingSettings::new)
);
}

private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) {
return switch (chunkingStrategy) {
case WORD -> new WordBoundaryChunker();
case SENTENCE -> new SentenceBoundaryChunker();
case RECURSIVE -> new RecursiveChunker();
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

public class ChunkerUtils {

    /**
     * Counts the words between {@code start} and {@code end} (inclusive of the word
     * containing {@code start}) using an ICU word {@link BreakIterator}.
     * <p>
     * The caller must have already invoked {@code setText()} on the iterator with the
     * text these offsets refer to.
     */
    static int countWords(int start, int end, BreakIterator wordIterator) {
        assert start < end;
        // Rewind to the boundary that begins the word containing `start`.
        wordIterator.preceding(start);

        int count = 0;
        // Walk boundaries forward until we pass `end` or exhaust the text, counting
        // only boundaries that begin an actual word (rule status != WORD_NONE filters
        // out whitespace/punctuation segments).
        for (int boundary = wordIterator.current(); boundary != BreakIterator.DONE && boundary <= end; boundary = wordIterator.next()) {
            if (wordIterator.getRuleStatus() != BreakIterator.WORD_NONE) {
                count++;
            }
        }

        return count;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
return switch (chunkingStrategy) {
case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings));
case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings));
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ public enum ChunkingSettingsOptions {
STRATEGY("strategy"),
MAX_CHUNK_SIZE("max_chunk_size"),
OVERLAP("overlap"),
SENTENCE_OVERLAP("sentence_overlap");
SENTENCE_OVERLAP("sentence_overlap"),
SEPARATOR_SET("separator_set"),
SEPARATORS("separators");

private final String chunkingSettingsOption;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.ChunkingSettings;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

public class RecursiveChunker implements Chunker {
private final BreakIterator wordIterator;

public RecursiveChunker() {
wordIterator = BreakIterator.getWordInstance();
}

@Override
public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings) {
if (chunkingSettings instanceof RecursiveChunkingSettings recursiveChunkingSettings) {
return chunk(input, recursiveChunkingSettings.getSeparators(), recursiveChunkingSettings.getMaxChunkSize(), 0, 0);
} else {
throw new IllegalArgumentException(
Strings.format("RecursiveChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy())
);
}
}

private List<ChunkOffset> chunk(String input, List<String> separators, int maxChunkSize, int separatorIndex, int chunkOffset) {
if (input.length() < 2 || isChunkWithinMaxSize(input, new ChunkOffset(0, input.length()), maxChunkSize)) {
return List.of(new ChunkOffset(chunkOffset, chunkOffset + input.length()));
}

if (separatorIndex > separators.size() - 1) {
return chunkWithBackupChunker(input, maxChunkSize, chunkOffset);
}

var potentialChunks = mergeChunkOffsetsUpToMaxChunkSize(
input,
splitTextBySeparatorRegex(input, separators.get(separatorIndex)),
maxChunkSize
);
var actualChunks = new ArrayList<ChunkOffset>();
for (var potentialChunk : potentialChunks) {
if (isChunkWithinMaxSize(input, potentialChunk, maxChunkSize)) {
actualChunks.add(new ChunkOffset(chunkOffset + potentialChunk.start(), chunkOffset + potentialChunk.end()));
} else {
actualChunks.addAll(
chunk(
input.substring(potentialChunk.start(), potentialChunk.end()),
separators,
maxChunkSize,
separatorIndex + 1,
chunkOffset + potentialChunk.start()
)
);
}
}

return actualChunks;
}

private boolean isChunkWithinMaxSize(String fullText, ChunkOffset chunk, int maxChunkSize) {
wordIterator.setText(fullText);
return ChunkerUtils.countWords(chunk.start(), chunk.end(), wordIterator) <= maxChunkSize;
}

private List<ChunkOffset> splitTextBySeparatorRegex(String input, String separatorRegex) {
var pattern = Pattern.compile(separatorRegex);
var matcher = pattern.matcher(input);

var chunkOffsets = new ArrayList<ChunkOffset>();
int chunkStart = 0;
int searchStart = 0;
while (matcher.find(searchStart)) {
var chunkEnd = matcher.start();
if (chunkStart < chunkEnd) {
chunkOffsets.add(new ChunkOffset(chunkStart, chunkEnd));
}
chunkStart = matcher.start();
searchStart = matcher.end();
}

if (chunkStart < input.length()) {
chunkOffsets.add(new ChunkOffset(chunkStart, input.length()));
}

return chunkOffsets;
}

private List<ChunkOffset> mergeChunkOffsetsUpToMaxChunkSize(String input, List<ChunkOffset> chunkOffsets, int maxChunkSize) {
if (chunkOffsets.size() < 2) {
return chunkOffsets;
}

List<ChunkOffset> mergedOffsets = new ArrayList<>();
var mergedChunk = chunkOffsets.getFirst();
for (int i = 1; i < chunkOffsets.size(); i++) {
var potentialMergedChunk = new ChunkOffset(mergedChunk.start(), chunkOffsets.get(i).end());
if (isChunkWithinMaxSize(input, potentialMergedChunk, maxChunkSize)) {
mergedChunk = potentialMergedChunk;
} else {
mergedOffsets.add(mergedChunk);
mergedChunk = chunkOffsets.get(i);
}

if (i == chunkOffsets.size() - 1) {
mergedOffsets.add(mergedChunk);
}
}
return mergedOffsets;
}

private List<ChunkOffset> chunkWithBackupChunker(String input, int maxChunkSize, int chunkOffset) {
var chunks = new SentenceBoundaryChunker().chunk(input, new SentenceBoundaryChunkingSettings(maxChunkSize, 0));
var chunksWithOffsets = new ArrayList<ChunkOffset>();
for (var chunk : chunks) {
chunksWithOffsets.add(new ChunkOffset(chunk.start() + chunkOffset, chunk.end() + chunkOffset));
}
return chunksWithOffsets;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class RecursiveChunkingSettings implements ChunkingSettings {
public static final String NAME = "RecursiveChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.RECURSIVE;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;

private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
ChunkingSettingsOptions.SEPARATORS.toString()
);

private final int maxChunkSize;
private final List<String> separators;

public RecursiveChunkingSettings(int maxChunkSize, List<String> separators) {
this.maxChunkSize = maxChunkSize;
this.separators = separators == null ? SeparatorSet.PLAINTEXT.getSeparators() : separators;
}

public RecursiveChunkingSettings(StreamInput in) throws IOException {
maxChunkSize = in.readInt();
separators = in.readCollectionAsList(StreamInput::readString);
}

public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray();
if (invalidSettings.length > 0) {
validationException.addValidationError(
Strings.format("Recursive chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings))
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

SeparatorSet separatorSet = ServiceUtils.extractOptionalEnum(
map,
ChunkingSettingsOptions.SEPARATOR_SET.toString(),
ModelConfigurations.CHUNKING_SETTINGS,
SeparatorSet::fromString,
EnumSet.allOf(SeparatorSet.class),
validationException
);

List<String> separators = ServiceUtils.extractOptionalList(
map,
ChunkingSettingsOptions.SEPARATORS.toString(),
String.class,
validationException
);

if (separators != null && separatorSet != null) {
validationException.addValidationError("Recursive chunking settings can not have both separators and separator_set");
}

if (separatorSet != null) {
separators = separatorSet.getSeparators();
} else if (separators != null && separators.isEmpty()) {
validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new RecursiveChunkingSettings(maxChunkSize, separators);
}

public int getMaxChunkSize() {
return maxChunkSize;
}

public List<String> getSeparators() {
return separators;
}

@Override
public ChunkingStrategy getChunkingStrategy() {
return STRATEGY;
}

@Override
public Map<String, Object> asMap() {
return Map.of(
ChunkingSettingsOptions.STRATEGY.toString(),
STRATEGY.toString().toLowerCase(Locale.ROOT),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
maxChunkSize,
ChunkingSettingsOptions.SEPARATORS.toString(),
separators
);
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return null; // TODO: Add transport version
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(maxChunkSize);
out.writeCollection(separators, StreamOutput::writeString);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
{
builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY);
builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
builder.field(ChunkingSettingsOptions.SEPARATORS.toString(), separators);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RecursiveChunkingSettings that = (RecursiveChunkingSettings) o;
return Objects.equals(maxChunkSize, that.maxChunkSize) && Objects.equals(separators, that.separators);
}

@Override
public int hashCode() {
return Objects.hash(maxChunkSize, separators);
}
}
Loading
Loading