Skip to content

Commit 849147e

Browse files
Add RerankRequestChunker
1 parent 47eada6 commit 849147e

File tree

3 files changed

+112
-2
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ protected void doInference(
7474
InferenceService service,
7575
ActionListener<InferenceServiceResults> listener
7676
) {
77+
// var rerankChunker = new RerankRequestChunker(request.getInput());
78+
7779
service.infer(
7880
model,
7981
request.getQuery(),
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.chunking;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.inference.ChunkingSettings;
12+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
13+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
14+
15+
import java.util.ArrayList;
16+
import java.util.HashMap;
17+
import java.util.List;
18+
import java.util.Map;
19+
20+
public class RerankRequestChunker {
21+
22+
private final ChunkingSettings chunkingSettings;
23+
private final List<String> inputs;
24+
private final Map<Integer, RerankChunks> rerankChunks;
25+
26+
public RerankRequestChunker(List<String> inputs) {
27+
// TODO: Make chunking settings dependent on the model being used.
28+
// There may be a way to do this dynamically knowing the max token size for the model/service and query size
29+
// instead of hardcoding it ona model/service basis.
30+
this.chunkingSettings = new WordBoundaryChunkingSettings(100, 10);
31+
this.inputs = inputs;
32+
this.rerankChunks = chunk(inputs, chunkingSettings);
33+
}
34+
35+
private Map<Integer, RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings) {
36+
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
37+
var chunks = new HashMap<Integer, RerankChunks>();
38+
var chunkIndex = 0;
39+
for (int i = 0; i < inputs.size(); i++) {
40+
var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings);
41+
for (var chunk : chunksForInput) {
42+
chunks.put(chunkIndex, new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end())));
43+
chunkIndex++;
44+
}
45+
}
46+
return chunks;
47+
}
48+
49+
public List<String> getChunkedInputs() {
50+
List<String> chunkedInputs = new ArrayList<>();
51+
for (RerankChunks chunk : rerankChunks.values()) {
52+
chunkedInputs.add(chunk.chunkString());
53+
}
54+
// TODO: Score the inputs here and only return the top N chunks for each document
55+
return chunkedInputs;
56+
}
57+
58+
public ActionListener<InferenceAction.Response> parseChunkedRerankResultsListener(ActionListener<InferenceAction.Response> listener) {
59+
return ActionListener.wrap(results -> {
60+
if (results.getResults() instanceof RankedDocsResults rankedDocsResults) {
61+
listener.onResponse(new InferenceAction.Response(parseRankedDocResultsForChunks(rankedDocsResults)));
62+
// TODO: Figure out if the above correctly creates the response or if it loses any info
63+
64+
} else {
65+
listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
66+
}
67+
68+
}, listener::onFailure);
69+
}
70+
71+
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
72+
Map<Integer, RankedDocsResults.RankedDoc> bestRankedDocResultPerDoc = new HashMap<>();
73+
for (var rankedDoc : rankedDocsResults.getRankedDocs()) {
74+
int chunkIndex = rankedDoc.index();
75+
int docIndex = rerankChunks.get(chunkIndex).docIndex();
76+
if (bestRankedDocResultPerDoc.containsKey(docIndex)) {
77+
RankedDocsResults.RankedDoc existingDoc = bestRankedDocResultPerDoc.get(docIndex);
78+
if (rankedDoc.relevanceScore() > existingDoc.relevanceScore()) {
79+
bestRankedDocResultPerDoc.put(
80+
docIndex,
81+
new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex))
82+
);
83+
}
84+
} else {
85+
bestRankedDocResultPerDoc.put(
86+
docIndex,
87+
new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex))
88+
);
89+
}
90+
}
91+
var bestRankedDocResultPerDocList = new ArrayList<>(bestRankedDocResultPerDoc.values());
92+
bestRankedDocResultPerDocList.sort(
93+
(RankedDocsResults.RankedDoc d1, RankedDocsResults.RankedDoc d2) -> Float.compare(d2.relevanceScore(), d1.relevanceScore())
94+
);
95+
return new RankedDocsResults(bestRankedDocResultPerDocList);
96+
}
97+
98+
public record RerankChunks(int docIndex, String chunkString) {};
99+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1818
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1919
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
20+
import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker;
2021
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
2122
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
2223
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
@@ -119,9 +120,16 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
119120
inferenceListener.onResponse(new InferenceAction.Response(new RankedDocsResults(List.of())));
120121
} else {
121122
List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
122-
InferenceAction.Request inferenceRequest = generateRequest(featureData);
123+
RerankRequestChunker chunker = new RerankRequestChunker(featureData);
124+
InferenceAction.Request inferenceRequest = generateRequest(chunker.getChunkedInputs());
123125
try {
124-
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
126+
executeAsyncWithOrigin(
127+
client,
128+
INFERENCE_ORIGIN,
129+
InferenceAction.INSTANCE,
130+
inferenceRequest,
131+
chunker.parseChunkedRerankResultsListener(inferenceListener)
132+
);
125133
} finally {
126134
inferenceRequest.decRef();
127135
}
@@ -156,6 +164,7 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer
156164
}
157165

158166
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
167+
// TODO: Try running the RerankRequestChunker here.
159168
return new InferenceAction.Request(
160169
TaskType.RERANK,
161170
inferenceId,

0 commit comments

Comments (0)