Skip to content

Commit 389b11f

Browse files
Add rerank chunking for JinaAI
1 parent 48f38e5 commit 389b11f

File tree

15 files changed

+992
-39
lines changed

15 files changed

+992
-39
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilder.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,9 @@ public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWo
6767
var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN);
6868
return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1);
6969
}
70+
71+
public static ChunkingSettings buildChunkingSettingsForRerank(int rerankerWindowSize, int queryWordCount) {
72+
var chunkSizeWordCountWithFullQuery = rerankerWindowSize - queryWordCount;
73+
return new SentenceBoundaryChunkingSettings(Math.max(chunkSizeWordCountWithFullQuery, rerankerWindowSize / 2), 1);
74+
}
7075
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunker.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ public RerankRequestChunker(String query, List<String> inputs, Integer maxChunks
2828
this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc);
2929
}
3030

31+
public RerankRequestChunker(String query, List<String> inputs, int rerankerWindowSize, Integer maxChunksPerDoc) {
32+
this.inputs = inputs;
33+
this.rerankChunks = chunk(inputs, buildChunkingSettingsForRerank(rerankerWindowSize, query), maxChunksPerDoc);
34+
}
35+
3136
private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) {
3237
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
3338
var chunks = new ArrayList<RerankChunks>();
@@ -53,6 +58,25 @@ public List<String> getChunkedInputs() {
5358
return chunkedInputs;
5459
}
5560

61+
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(
62+
ActionListener<InferenceServiceResults> listener,
63+
boolean returnDocuments,
64+
Integer topN
65+
) {
66+
listener = parseChunkedRerankResultsListener(listener, returnDocuments);
67+
if (topN != null) {
68+
return listener.delegateFailureAndWrap((l, results) -> {
69+
if (results instanceof RankedDocsResults rankedDocsResults) {
70+
l.onResponse(new RankedDocsResults(rankedDocsResults.getRankedDocs().subList(0, topN)));
71+
} else {
72+
l.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass().getName()));
73+
}
74+
});
75+
} else {
76+
return listener;
77+
}
78+
}
79+
5680
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(
5781
ActionListener<InferenceServiceResults> listener,
5882
boolean returnDocuments
@@ -101,4 +125,11 @@ private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
101125
var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
102126
return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
103127
}
128+
129+
private ChunkingSettings buildChunkingSettingsForRerank(int rerankerWindowSize, String query) {
130+
var wordIterator = BreakIterator.getWordInstance();
131+
wordIterator.setText(query);
132+
var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
133+
return ChunkingSettingsBuilder.buildChunkingSettingsForRerank(rerankerWindowSize, queryWordCount);
134+
}
104135
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,31 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanH
7777
assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap());
7878
}
7979

80+
public void testBuildChunkingSettingsForRerank_QueryWordCountLessThanHalfOfRerankerWindowSize() {
81+
int rerankerWindowSize = randomIntBetween(20, 6000);
82+
int queryWordCount = randomIntBetween(1, rerankerWindowSize / 2 - 1);
83+
ChunkingSettings actualChunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForRerank(
84+
rerankerWindowSize,
85+
queryWordCount
86+
);
87+
SentenceBoundaryChunkingSettings expectedChunkingSettings = new SentenceBoundaryChunkingSettings(
88+
rerankerWindowSize - queryWordCount,
89+
1
90+
);
91+
assertEquals(expectedChunkingSettings, actualChunkingSettings);
92+
}
93+
94+
public void testBuildChunkingSettingsForRerank_QueryWordCountMoreThanHalfOfRerankerWindowSize() {
95+
int rerankerWindowSize = randomIntBetween(20, 6000);
96+
int queryWordCount = randomIntBetween(rerankerWindowSize / 2, Integer.MAX_VALUE);
97+
ChunkingSettings actualChunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForRerank(
98+
rerankerWindowSize,
99+
queryWordCount
100+
);
101+
SentenceBoundaryChunkingSettings expectedChunkingSettings = new SentenceBoundaryChunkingSettings(rerankerWindowSize / 2, 1);
102+
assertEquals(expectedChunkingSettings, actualChunkingSettings);
103+
}
104+
80105
private Map<Map<String, Object>, ChunkingSettings> chunkingSettingsMapToChunkingSettings() {
81106
var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300);
82107
var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunkerTests.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,52 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiring
274274
listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
275275
}
276276

277+
public void testParseChunkedRerankResultsListener_TopNSetToNull() {
278+
var inputs = List.of(generateTestText(10), generateTestText(100));
279+
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
280+
var returnDocuments = randomBoolean();
281+
var relevanceScores = generateRelevanceScores(4);
282+
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
283+
assertThat(results, instanceOf(RankedDocsResults.class));
284+
var rankedDocResults = (RankedDocsResults) results;
285+
var expectedRankedDocs = new ArrayList<RankedDocsResults.RankedDoc>();
286+
expectedRankedDocs.add(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null));
287+
expectedRankedDocs.add(
288+
new RankedDocsResults.RankedDoc(1, Collections.max(relevanceScores.subList(1, 4)), returnDocuments ? inputs.get(1) : null)
289+
);
290+
expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
291+
assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
292+
}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments, null);
293+
294+
var chunkedInputs = chunker.getChunkedInputs();
295+
assertEquals(4, chunkedInputs.size());
296+
listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
297+
}
298+
299+
public void testParseChunkedRerankResultsListener_TopNProvided() {
300+
var inputs = List.of(generateTestText(10), generateTestText(100));
301+
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
302+
var returnDocuments = randomBoolean();
303+
var relevanceScores = generateRelevanceScores(4);
304+
var topN = 1;
305+
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
306+
assertThat(results, instanceOf(RankedDocsResults.class));
307+
var rankedDocResults = (RankedDocsResults) results;
308+
assertEquals(topN, rankedDocResults.getRankedDocs().size());
309+
var expectedRankedDocs = new ArrayList<RankedDocsResults.RankedDoc>();
310+
expectedRankedDocs.add(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null));
311+
expectedRankedDocs.add(
312+
new RankedDocsResults.RankedDoc(1, Collections.max(relevanceScores.subList(1, 4)), returnDocuments ? inputs.get(1) : null)
313+
);
314+
expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
315+
assertEquals(expectedRankedDocs.subList(0, topN), rankedDocResults.getRankedDocs());
316+
}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments, topN);
317+
318+
var chunkedInputs = chunker.getChunkedInputs();
319+
assertEquals(4, chunkedInputs.size());
320+
listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
321+
}
322+
277323
private String generateTestText(int numSentences) {
278324
StringBuilder sb = new StringBuilder();
279325
for (int i = 0; i < numSentences; i++) {
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.core.TimeValue;
12+
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.inference.Model;
14+
import org.elasticsearch.inference.RerankingInferenceService;
15+
import org.elasticsearch.xpack.core.inference.chunking.RerankRequestChunker;
16+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
17+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
18+
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
19+
import org.elasticsearch.xpack.inference.services.settings.LongDocumentStrategy;
20+
import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings;
21+
22+
public class RerankInferenceProcessor {
23+
public static void doInfer(
24+
SenderService service,
25+
Model model,
26+
ExecutableAction action,
27+
InferenceInputs inputs,
28+
TimeValue timeout,
29+
ActionListener<InferenceServiceResults> listener
30+
) {
31+
var serviceSettings = model.getServiceSettings();
32+
if (serviceSettings instanceof RerankServiceSettings == false) {
33+
throw new IllegalArgumentException("RerankInferenceProcessor can only process models with RerankServiceSettings");
34+
}
35+
36+
var rerankServiceSettings = (RerankServiceSettings) serviceSettings;
37+
if (LongDocumentStrategy.CHUNK.equals(rerankServiceSettings.getLongDocumentStrategy())) {
38+
if (inputs instanceof QueryAndDocsInputs == false) {
39+
throw new IllegalArgumentException("RerankInferenceProcessor can only process QueryAndDocsInputs when chunking is enabled");
40+
}
41+
var queryAndDocsInputs = (QueryAndDocsInputs) inputs;
42+
43+
if (service instanceof RerankingInferenceService == false) {
44+
throw new IllegalArgumentException(
45+
"RerankInferenceProcessor can only process RerankingInferenceService when chunking is enabled"
46+
);
47+
}
48+
49+
var rerankRequestChunker = new RerankRequestChunker(
50+
queryAndDocsInputs.getQuery(),
51+
queryAndDocsInputs.getChunks(),
52+
((RerankingInferenceService) service).rerankerWindowSize(serviceSettings.modelId()),
53+
rerankServiceSettings.getMaxChunksPerDoc()
54+
);
55+
56+
inputs = new QueryAndDocsInputs(
57+
queryAndDocsInputs.getQuery(),
58+
rerankRequestChunker.getChunkedInputs(),
59+
queryAndDocsInputs.getReturnDocuments(), // TODO: Check if we want this from inputs or from task settings
60+
queryAndDocsInputs.getTopN(), // TODO: Check if we want this from inputs or from task settings
61+
queryAndDocsInputs.stream()
62+
);
63+
64+
listener = rerankRequestChunker.parseChunkedRerankResultsListener(
65+
listener,
66+
queryAndDocsInputs.getReturnDocuments() == null || queryAndDocsInputs.getReturnDocuments(),
67+
queryAndDocsInputs.getTopN()
68+
);
69+
}
70+
action.execute(inputs, timeout, listener);
71+
}
72+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.inference.TaskType;
2626
import org.elasticsearch.inference.UnifiedCompletionRequest;
2727
import org.elasticsearch.rest.RestStatus;
28+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2829
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2930
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3031
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -150,13 +151,25 @@ public void chunkedInfer(
150151
}).addListener(listener);
151152
}
152153

153-
protected abstract void doInfer(
154+
protected void doInfer(
154155
Model model,
155156
InferenceInputs inputs,
156157
Map<String, Object> taskSettings,
157158
TimeValue timeout,
158159
ActionListener<InferenceServiceResults> listener
159-
);
160+
) {
161+
var action = createAction(model, taskSettings);
162+
163+
if (Objects.requireNonNull(model.getTaskType()) == TaskType.RERANK) {
164+
RerankInferenceProcessor.doInfer(this, model, action, inputs, timeout, listener);
165+
} else {
166+
action.execute(inputs, timeout, listener);
167+
}
168+
}
169+
170+
protected ExecutableAction createAction(Model model, Map<String, Object> taskSettings) {
171+
return null;
172+
}
160173

161174
protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);
162175

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
3333
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
3434
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
35+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
3536
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3637
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
37-
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3838
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
3939
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
4040
import org.elasticsearch.xpack.inference.services.SenderService;
@@ -235,23 +235,14 @@ protected void doUnifiedCompletionInfer(
235235
}
236236

237237
@Override
238-
public void doInfer(
239-
Model model,
240-
InferenceInputs inputs,
241-
Map<String, Object> taskSettings,
242-
TimeValue timeout,
243-
ActionListener<InferenceServiceResults> listener
244-
) {
245-
if (model instanceof JinaAIModel == false) {
246-
listener.onFailure(createInvalidModelException(model));
247-
return;
248-
}
249-
250-
JinaAIModel jinaaiModel = (JinaAIModel) model;
238+
protected ExecutableAction createAction(Model model, Map<String, Object> taskSettings) {
251239
var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());
252240

253-
var action = jinaaiModel.accept(actionCreator, taskSettings);
254-
action.execute(inputs, timeout, listener);
241+
if (model instanceof JinaAIModel jinaaiModel) {
242+
return jinaaiModel.accept(actionCreator, taskSettings);
243+
} else {
244+
throw createInvalidModelException(model);
245+
}
255246
}
256247

257248
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2020
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIRateLimitServiceSettings;
2121
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
22-
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
2322
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
23+
import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings;
2424

2525
import java.io.IOException;
2626
import java.util.Map;
2727
import java.util.Objects;
2828

29-
public class JinaAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, JinaAIRateLimitServiceSettings {
29+
public class JinaAIRerankServiceSettings extends RerankServiceSettings implements ServiceSettings, JinaAIRateLimitServiceSettings {
3030
public static final String NAME = "jinaai_rerank_service_settings";
3131

3232
private static final Logger logger = LogManager.getLogger(JinaAIRerankServiceSettings.class);
@@ -38,18 +38,26 @@ public static JinaAIRerankServiceSettings fromMap(Map<String, Object> map, Confi
3838
throw validationException;
3939
}
4040

41+
var rerankServiceSettings = RerankServiceSettings.fromMap(map);
4142
var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context);
4243

43-
return new JinaAIRerankServiceSettings(commonServiceSettings);
44+
return new JinaAIRerankServiceSettings(rerankServiceSettings, commonServiceSettings);
4445
}
4546

4647
private final JinaAIServiceSettings commonSettings;
4748

49+
/**
 * Creates settings from rerank settings plus the common JinaAI service settings.
 *
 * @param rerankSettings passed to the superclass constructor
 * @param commonSettings the JinaAI settings shared across task types
 */
public JinaAIRerankServiceSettings(RerankServiceSettings rerankSettings, JinaAIServiceSettings commonSettings) {
    super(rerankSettings);
    this.commonSettings = commonSettings;
}
53+
4854
/**
 * Creates settings from the common JinaAI service settings only; both superclass constructor arguments
 * are passed as {@code null}.
 *
 * @param commonSettings the JinaAI settings shared across task types
 */
public JinaAIRerankServiceSettings(JinaAIServiceSettings commonSettings) {
    super(null, null);
    this.commonSettings = commonSettings;
}
5158

5259
/**
 * Reads settings from a stream: the superclass state first, then the common JinaAI settings.
 * This order mirrors {@code writeTo}.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public JinaAIRerankServiceSettings(StreamInput in) throws IOException {
    super(in);
    this.commonSettings = new JinaAIServiceSettings(in);
}
5563

@@ -76,14 +84,16 @@ public String getWriteableName() {
7684
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
7785
builder.startObject();
7886

87+
builder = super.toXContentFragmentOfExposedFields(builder, params);
7988
builder = commonSettings.toXContentFragment(builder, params);
8089

8190
builder.endObject();
8291
return builder;
8392
}
8493

8594
@Override
86-
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
95+
public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
96+
super.toXContentFragmentOfExposedFields(builder, params);
8797
commonSettings.toXContentFragmentOfExposedFields(builder, params);
8898
return builder;
8999
}
@@ -95,6 +105,7 @@ public TransportVersion getMinimalSupportedVersion() {
95105

96106
@Override
97107
public void writeTo(StreamOutput out) throws IOException {
108+
super.writeTo(out);
98109
commonSettings.writeTo(out);
99110
}
100111

@@ -103,11 +114,12 @@ public boolean equals(Object o) {
103114
if (this == o) return true;
104115
if (o == null || getClass() != o.getClass()) return false;
105116
JinaAIRerankServiceSettings that = (JinaAIRerankServiceSettings) o;
106-
return Objects.equals(commonSettings, that.commonSettings);
117+
return super.equals(o) && Objects.equals(commonSettings, that.commonSettings);
107118
}
108119

109120
@Override
110121
public int hashCode() {
111122
return Objects.hash(commonSettings);
123+
// TODO: include super.hashcode?
112124
}
113125
}

0 commit comments

Comments
 (0)