Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
849147e
Add RerankRequestChunker
dan-rubinstein Jun 10, 2025
c41d54c
Merge branch 'main' into rerank-request-chunker
elasticmachine Jul 3, 2025
da4c939
Add chunking strategy generation
dan-rubinstein Jul 4, 2025
004ca8f
Merge branch 'main' into rerank-request-chunker
davidkyle Jul 18, 2025
5ec620a
Merge branch 'main' into rerank-request-chunker
elasticmachine Jul 30, 2025
4ff8eb0
Adding unit tests and fixing token/word ratio
dan-rubinstein Jul 23, 2025
ec78b87
Merge branch 'main' into rerank-request-chunker
elasticmachine Aug 13, 2025
9ef8917
Add configurable values for long document handling strategy and maxim…
dan-rubinstein Sep 8, 2025
24497ae
Adding back sentence overlap for rerank chunking strategy
dan-rubinstein Sep 11, 2025
1fea365
Merge branch 'main' into rerank-request-chunker
elasticmachine Sep 11, 2025
8396214
Merge branch 'main' into rerank-request-chunker
elasticmachine Sep 22, 2025
8b97711
Adding unit tests, transport version, and feature flag
dan-rubinstein Sep 18, 2025
833ef02
Update docs/changelog/130485.yaml
dan-rubinstein Sep 22, 2025
77701e1
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 25, 2025
344e121
Adding unit tests and refactoring code with clearer naming conventions
dan-rubinstein Sep 25, 2025
02c9d0a
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 29, 2025
d68bf09
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ protected void doInference(
InferenceService service,
ActionListener<InferenceServiceResults> listener
) {
// var rerankChunker = new RerankRequestChunker(request.getInput());

service.infer(
model,
request.getQuery(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Splits rerank inputs into word-boundary chunks before they are sent to the inference
 * service, and folds the per-chunk rerank results back into one best-scoring result per
 * original input document.
 */
public class RerankRequestChunker {

    private final ChunkingSettings chunkingSettings;
    // The original, unchunked inputs; used to rebuild per-document results.
    private final List<String> inputs;
    // Ordered chunks: a chunk's position in this list is the index the inference service
    // sees, so the order here must match the order returned by getChunkedInputs().
    // (A List is used rather than a HashMap keyed by a counter, because HashMap iteration
    // order is unspecified and would not reliably line up with the ranked-doc indices.)
    private final List<RerankChunks> rerankChunks;

    public RerankRequestChunker(List<String> inputs) {
        // TODO: Make chunking settings dependent on the model being used.
        // There may be a way to do this dynamically knowing the max token size for the model/service and query size
        // instead of hardcoding it on a model/service basis.
        this.chunkingSettings = new WordBoundaryChunkingSettings(100, 10);
        this.inputs = inputs;
        this.rerankChunks = chunk(inputs, chunkingSettings);
    }

    /**
     * Chunks each input with the strategy from the given settings, recording for every
     * chunk the index of the document it came from.
     */
    private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings) {
        var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
        var chunks = new ArrayList<RerankChunks>();
        for (int docIndex = 0; docIndex < inputs.size(); docIndex++) {
            var input = inputs.get(docIndex);
            for (var chunk : chunker.chunk(input, chunkingSettings)) {
                chunks.add(new RerankChunks(docIndex, input.substring(chunk.start(), chunk.end())));
            }
        }
        return chunks;
    }

    /**
     * Returns the chunk strings in chunk-index order; position i here corresponds to
     * ranked-doc index i in the inference response.
     */
    public List<String> getChunkedInputs() {
        List<String> chunkedInputs = new ArrayList<>(rerankChunks.size());
        for (RerankChunks chunk : rerankChunks) {
            chunkedInputs.add(chunk.chunkString());
        }
        // TODO: Score the inputs here and only return the top N chunks for each document
        return chunkedInputs;
    }

    /**
     * Wraps the given listener so that per-chunk {@link RankedDocsResults} are collapsed
     * back into per-document results before being passed on. Any non-RankedDocsResults
     * response is reported as a failure.
     */
    public ActionListener<InferenceAction.Response> parseChunkedRerankResultsListener(ActionListener<InferenceAction.Response> listener) {
        return ActionListener.wrap(results -> {
            if (results.getResults() instanceof RankedDocsResults rankedDocsResults) {
                // TODO: Figure out if the below correctly creates the response or if it loses any info
                listener.onResponse(new InferenceAction.Response(parseRankedDocResultsForChunks(rankedDocsResults)));
            } else {
                listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
            }
        }, listener::onFailure);
    }

    /**
     * Keeps only the highest-scoring chunk per original document, rebuilding each result
     * with the document's index and full (unchunked) text, sorted by descending score.
     */
    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
        Map<Integer, RankedDocsResults.RankedDoc> bestRankedDocResultPerDoc = new HashMap<>();
        for (var rankedDoc : rankedDocsResults.getRankedDocs()) {
            // rankedDoc.index() addresses the chunk list built in chunk(...)
            int docIndex = rerankChunks.get(rankedDoc.index()).docIndex();
            bestRankedDocResultPerDoc.merge(
                docIndex,
                new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)),
                (existing, candidate) -> candidate.relevanceScore() > existing.relevanceScore() ? candidate : existing
            );
        }
        var bestRankedDocResultPerDocList = new ArrayList<>(bestRankedDocResultPerDoc.values());
        // Rerank responses are expected best-first, so sort by descending relevance score.
        bestRankedDocResultPerDocList.sort(
            (RankedDocsResults.RankedDoc d1, RankedDocsResults.RankedDoc d2) -> Float.compare(d2.relevanceScore(), d1.relevanceScore())
        );
        return new RankedDocsResults(bestRankedDocResultPerDocList);
    }

    /** A single chunk of an input document, tagged with the index of that document. */
    public record RerankChunks(int docIndex, String chunkString) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
Expand Down Expand Up @@ -119,9 +120,16 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
inferenceListener.onResponse(new InferenceAction.Response(new RankedDocsResults(List.of())));
} else {
List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
InferenceAction.Request inferenceRequest = generateRequest(featureData);
RerankRequestChunker chunker = new RerankRequestChunker(featureData);
InferenceAction.Request inferenceRequest = generateRequest(chunker.getChunkedInputs());
try {
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
executeAsyncWithOrigin(
client,
INFERENCE_ORIGIN,
InferenceAction.INSTANCE,
inferenceRequest,
chunker.parseChunkedRerankResultsListener(inferenceListener)
);
} finally {
inferenceRequest.decRef();
}
Expand Down Expand Up @@ -156,6 +164,7 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer
}

protected InferenceAction.Request generateRequest(List<String> docFeatures) {
// TODO: Try running the RerankRequestChunker here.
return new InferenceAction.Request(
TaskType.RERANK,
inferenceId,
Expand Down