Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
849147e
Add RerankRequestChunker
dan-rubinstein Jun 10, 2025
c41d54c
Merge branch 'main' into rerank-request-chunker
elasticmachine Jul 3, 2025
da4c939
Add chunking strategy generation
dan-rubinstein Jul 4, 2025
004ca8f
Merge branch 'main' into rerank-request-chunker
davidkyle Jul 18, 2025
5ec620a
Merge branch 'main' into rerank-request-chunker
elasticmachine Jul 30, 2025
4ff8eb0
Adding unit tests and fixing token/word ratio
dan-rubinstein Jul 23, 2025
ec78b87
Merge branch 'main' into rerank-request-chunker
elasticmachine Aug 13, 2025
9ef8917
Add configurable values for long document handling strategy and maxim…
dan-rubinstein Sep 8, 2025
24497ae
Adding back sentence overlap for rerank chunking strategy
dan-rubinstein Sep 11, 2025
1fea365
Merge branch 'main' into rerank-request-chunker
elasticmachine Sep 11, 2025
8396214
Merge branch 'main' into rerank-request-chunker
elasticmachine Sep 22, 2025
8b97711
Adding unit tests, transport version, and feature flag
dan-rubinstein Sep 18, 2025
833ef02
Update docs/changelog/130485.yaml
dan-rubinstein Sep 22, 2025
77701e1
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 25, 2025
344e121
Adding unit tests and refactoring code with clearer naming conventions
dan-rubinstein Sep 25, 2025
02c9d0a
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 29, 2025
d68bf09
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ protected void doInference(
InferenceService service,
ActionListener<InferenceServiceResults> listener
) {

service.infer(
model,
request.getQuery(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public class ChunkingSettingsBuilder {
public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1);
// Old settings used for backward compatibility for endpoints created before 8.16 when default was changed
public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512;
public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3;
public static final float WORDS_PER_TOKEN = 0.75f;

public static ChunkingSettings fromMap(Map<String, Object> settings) {
return fromMap(settings, true);
Expand Down Expand Up @@ -51,4 +54,17 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
};
}

public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWordCount) {
var queryTokenCount = Math.ceil(queryWordCount / WORDS_PER_TOKEN);
var chunkSizeTokenCountWithFullQuery = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount);

var maxChunkSizeTokenCount = Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2);
if (chunkSizeTokenCountWithFullQuery > maxChunkSizeTokenCount) {
maxChunkSizeTokenCount = chunkSizeTokenCountWithFullQuery;
}

var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN);
return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class RerankRequestChunker {
private final List<String> inputs;
private final List<RerankChunks> rerankChunks;

public RerankRequestChunker(String query, List<String> inputs, Integer maxChunksPerDoc) {
this.inputs = inputs;
this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc);
}

private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) {
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
var chunks = new ArrayList<RerankChunks>();
for (int i = 0; i < inputs.size(); i++) {
var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings);
if (maxChunksPerDoc != null && chunksForInput.size() > maxChunksPerDoc) {
var limitedChunks = chunksForInput.subList(0, maxChunksPerDoc - 1);
var lastChunk = limitedChunks.getLast();
limitedChunks.add(new Chunker.ChunkOffset(lastChunk.end(), inputs.get(i).length()));
chunksForInput = limitedChunks;
}

for (var chunk : chunksForInput) {
chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end())));
}
}
return chunks;
}

public List<String> getChunkedInputs() {
List<String> chunkedInputs = new ArrayList<>();
for (RerankChunks chunk : rerankChunks) {
chunkedInputs.add(chunk.chunkString());
}

// TODO: Score the inputs here and only return the top N chunks for each document
return chunkedInputs;
}

public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
return ActionListener.wrap(results -> {
if (results instanceof RankedDocsResults rankedDocsResults) {
listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));

} else {
listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
}

}, listener::onFailure);
}

// TODO: Can we assume the rankeddocsresults are always sorted by relevance score?
// TODO: Should we short circuit if no chunking was done?
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
List<RankedDocsResults.RankedDoc> updatedRankedDocs = new ArrayList<>();
Set<Integer> docIndicesSeen = new HashSet<>();
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be safe and ensure the highest scoring chunk is used rankedDocsResults should be sorted. The results almost certainly will be sorted but just in case.

The sorting could be done in the RankedDocsResults constructor

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, good catch, I added the sort at the end of this function but it should be in the construction to cover cases when the results aren't sorted but it should be in the rankedDocsResults.getRankedDocs() call to ensure we are taking the top result for each doc. I'll update this to sort the ranked docs before looping and will also update the updatedRankedDocs to be topRankedDocs as I think that's a bit clearer on what we're trying to store.

int chunkIndex = rankedDoc.index();
int docIndex = rerankChunks.get(chunkIndex).docIndex();

if (docIndicesSeen.contains(docIndex) == false) {
// Create a ranked doc with the full input string and the index for the document instead of the chunk
RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
docIndex,
rankedDoc.relevanceScore(),
inputs.get(docIndex)
);
updatedRankedDocs.add(updatedRankedDoc);
docIndicesSeen.add(docIndex);
}
}

return new RankedDocsResults(updatedRankedDocs);
}

public record RerankChunks(int docIndex, String chunkString) {};

private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
var wordIterator = BreakIterator.getWordInstance();
wordIterator.setText(query);
var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer
}

protected InferenceAction.Request generateRequest(List<String> docFeatures) {
// TODO: Try running the RerankRequestChunker here.
return new InferenceAction.Request(
TaskType.RERANK,
inferenceId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,49 @@

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;

public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {

public static final String NAME = "elastic_reranker_service_settings";

private static final String LONG_DOCUMENT_HANDLING_STRATEGY = "long_document_handling_strategy";
private static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc";

private final LongDocumentHandlingStrategy longDocumentHandlingStrategy;
private final Integer maxChunksPerDoc;

public static ElasticRerankerServiceSettings defaultEndpointSettings() {
return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32));
}

public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
super(other);
this.longDocumentHandlingStrategy = null;
this.maxChunksPerDoc = null;
}

public ElasticRerankerServiceSettings(
ElasticsearchInternalServiceSettings other,
LongDocumentHandlingStrategy longDocumentHandlingStrategy,
Integer maxChunksPerDoc
) {
super(other);
this.longDocumentHandlingStrategy = longDocumentHandlingStrategy;
this.maxChunksPerDoc = maxChunksPerDoc;

}

private ElasticRerankerServiceSettings(
Expand All @@ -35,10 +61,15 @@ private ElasticRerankerServiceSettings(
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
this.longDocumentHandlingStrategy = null;
this.maxChunksPerDoc = null;
}

public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
super(in);
// TODO: Add transport version here
this.longDocumentHandlingStrategy = in.readOptionalEnum(LongDocumentHandlingStrategy.class);
this.maxChunksPerDoc = in.readOptionalInt();
}

/**
Expand All @@ -48,21 +79,89 @@ public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
* {@link ValidationException} is thrown.
*
* @param map Source map containing the config
* @return The builder
* @return Parsed and validated service settings
*/
public static Builder fromRequestMap(Map<String, Object> map) {
public static ElasticRerankerServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);

LongDocumentHandlingStrategy longDocumentHandlingStrategy = extractOptionalEnum(
map,
LONG_DOCUMENT_HANDLING_STRATEGY,
ModelConfigurations.SERVICE_SETTINGS,
LongDocumentHandlingStrategy::fromString,
EnumSet.allOf(LongDocumentHandlingStrategy.class),
validationException
);

Integer maxChunksPerDoc = extractOptionalPositiveInteger(
map,
MAX_CHUNKS_PER_DOC,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);

if (maxChunksPerDoc != null
&& (longDocumentHandlingStrategy == null || longDocumentHandlingStrategy == LongDocumentHandlingStrategy.TRUNCATE)) {
validationException.addValidationError(
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_HANDLING_STRATEGY + "] to be set to [chunk]"
);
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return baseSettings;
return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentHandlingStrategy, maxChunksPerDoc);
}

public LongDocumentHandlingStrategy getLongDocumentHandlingStrategy() {
return longDocumentHandlingStrategy;
}

public Integer getMaxChunksPerDoc() {
return maxChunksPerDoc;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
// TODO: Add transport version here
out.writeOptionalEnum(longDocumentHandlingStrategy);
out.writeOptionalInt(maxChunksPerDoc);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
addInternalSettingsToXContent(builder, params);
if (longDocumentHandlingStrategy != null) {
builder.field(LONG_DOCUMENT_HANDLING_STRATEGY, longDocumentHandlingStrategy.strategyName);
}
if (maxChunksPerDoc != null) {
builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return ElasticRerankerServiceSettings.NAME;
}

public enum LongDocumentHandlingStrategy {
CHUNK("chunk"),
TRUNCATE("truncate");

public final String strategyName;

LongDocumentHandlingStrategy(String strategyName) {
this.strategyName = strategyName;
}

public static LongDocumentHandlingStrategy fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

Expand Down Expand Up @@ -349,19 +350,13 @@ private void rerankerCase(
ActionListener<Model> modelListener
) {

var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap);
var serviceSettings = ElasticRerankerServiceSettings.fromMap(serviceSettingsMap);

throwIfNotEmptyMap(config, name());
throwIfNotEmptyMap(serviceSettingsMap, name());

modelListener.onResponse(
new ElasticRerankerModel(
inferenceEntityId,
taskType,
NAME,
new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
RerankTaskSettings.fromMap(taskSettingsMap)
)
new ElasticRerankerModel(inferenceEntityId, taskType, NAME, serviceSettings, RerankTaskSettings.fromMap(taskSettingsMap))
);
}

Expand Down Expand Up @@ -535,7 +530,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
inferenceEntityId,
taskType,
NAME,
new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)),
ElasticRerankerServiceSettings.fromMap(serviceSettingsMap),
RerankTaskSettings.fromMap(taskSettingsMap)
);
} else {
Expand Down Expand Up @@ -688,7 +683,25 @@ public void inferRerank(
Map<String, Object> requestTaskSettings,
ActionListener<InferenceServiceResults> listener
) {
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
var chunkedInputs = inputs;
var resultsListener = listener;
if (model instanceof ElasticRerankerModel elasticRerankerModel) {
var serviceSettings = elasticRerankerModel.getServiceSettings();
var longDocumentHandlingStrategy = serviceSettings.getLongDocumentHandlingStrategy();
if (longDocumentHandlingStrategy == ElasticRerankerServiceSettings.LongDocumentHandlingStrategy.CHUNK) {
var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc());
chunkedInputs = rerankChunker.getChunkedInputs();
resultsListener = rerankChunker.parseChunkedRerankResultsListener(listener);
}

}
var request = buildInferenceRequest(
model.mlNodeDeploymentId(),
new TextSimilarityConfigUpdate(query),
chunkedInputs,
inputType,
timeout
);

var returnDocs = Boolean.TRUE;
if (returnDocuments != null) {
Expand All @@ -698,9 +711,9 @@ public void inferRerank(
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
}

Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? chunkedInputs::get : i -> null;

ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
ActionListener<InferModelAction.Response> mlResultsListener = resultsListener.delegateFailureAndWrap(
(l, inferenceResult) -> l.onResponse(
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN)
)
Expand Down
Loading