Skip to content

Commit 129efac

Browse files
committed
Refactor method signatures to provide only a single custom input, fix tests so they use highlighting
1 parent 4675586 commit 129efac

File tree

12 files changed

+301
-98
lines changed

12 files changed

+301
-98
lines changed

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ public void onFailure(Exception e) {
173173
queryResult.getContextId(),
174174
queryResult.getShardSearchRequest(),
175175
entry,
176-
rankFeaturePhaseRankCoordinatorContext.snippets(),
177-
rankFeaturePhaseRankCoordinatorContext.tokenSizeLimit()
176+
rankFeaturePhaseRankCoordinatorContext.customRankInput()
178177
),
179178
context.getTask(),
180179
listener

server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import org.apache.lucene.search.ScoreDoc;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.search.rank.feature.CustomRankInput;
1516
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
16-
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
1717

1818
import java.util.Arrays;
1919
import java.util.Comparator;
@@ -32,7 +32,7 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
3232
protected final int from;
3333
protected final int rankWindowSize;
3434
protected final boolean failuresAllowed;
35-
protected final RerankSnippetInput snippets;
35+
protected final CustomRankInput customRankInput;
3636

3737
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed) {
3838
this(size, from, rankWindowSize, failuresAllowed, null);
@@ -43,31 +43,24 @@ public RankFeaturePhaseRankCoordinatorContext(
4343
int from,
4444
int rankWindowSize,
4545
boolean failuresAllowed,
46-
@Nullable RerankSnippetInput snippets
46+
@Nullable CustomRankInput customRankInput
4747
) {
4848
this.size = size < 0 ? DEFAULT_SIZE : size;
4949
this.from = from < 0 ? DEFAULT_FROM : from;
5050
this.rankWindowSize = rankWindowSize;
5151
this.failuresAllowed = failuresAllowed;
52-
this.snippets = snippets;
52+
this.customRankInput = customRankInput;
5353
}
5454

5555
public boolean failuresAllowed() {
5656
return failuresAllowed;
5757
}
5858

5959
/**
60-
* If non-null, we will rerank based on the best-ranking snippet rather than the whole text.
60+
* If non-null, we will use this custom input when computing reranked results
6161
*/
62-
public RerankSnippetInput snippets() {
63-
return snippets;
64-
}
65-
66-
/**
67-
* If snippets are requested, this should be overridden with the token size limit of the associated model.
68-
*/
69-
public Integer tokenSizeLimit() {
70-
return 0;
62+
public CustomRankInput customRankInput() {
63+
return customRankInput;
7164
}
7265

7366
/**
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.rank.feature;
11+
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
14+
/**
15+
* Defines custom rank input that we send in as input to rank retrievers
16+
*/
17+
public interface CustomRankInput extends Writeable {
18+
19+
/**
20+
* @return unique identifier for this type of input
21+
*/
22+
String name();
23+
}

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.apache.logging.log4j.LogManager;
1313
import org.apache.logging.log4j.Logger;
14+
import org.elasticsearch.index.query.MatchQueryBuilder;
1415
import org.elasticsearch.search.SearchContextSourcePrinter;
1516
import org.elasticsearch.search.SearchHits;
1617
import org.elasticsearch.search.fetch.FetchSearchResult;
@@ -56,18 +57,22 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard
5657
String field = rankFeaturePhaseRankShardContext.getField();
5758
assert field != null : "field must not be null";
5859
searchContext.fetchFieldsContext(new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(field, null))));
59-
RerankSnippetInput snippets = request.snippets();
60-
if (snippets != null) {
60+
CustomRankInput customRankInput = request.customRankInput();
61+
if (customRankInput instanceof SnippetRankInput snippetRankInput) {
6162
try {
63+
HighlightBuilder highlightBuilder = new HighlightBuilder();
64+
highlightBuilder.highlightQuery(new MatchQueryBuilder(field, snippetRankInput.inferenceText()));
6265
// Stripping pre/post tags as they're not useful for snippet creation
63-
HighlightBuilder highlightBuilder = new HighlightBuilder().field(field).preTags("").postTags("");
66+
highlightBuilder.field(field).preTags("").postTags("");
6467
// Return highest scoring fragments
6568
highlightBuilder.order(HighlightBuilder.Order.SCORE);
66-
int numSnippets = snippets.numSnippets() != null ? snippets.numSnippets() : DEFAULT_NUM_SNIPPETS;
69+
int numSnippets = snippetRankInput.snippets().numSnippets() != null
70+
? snippetRankInput.snippets().numSnippets()
71+
: DEFAULT_NUM_SNIPPETS;
6772
highlightBuilder.numOfFragments(numSnippets);
6873
// Rely on the model to determine the fragment size
6974
// TODO highlighter should be able to set fragment size by token not length
70-
highlightBuilder.fragmentSize(request.getTokenSizeLimit());
75+
highlightBuilder.fragmentSize(snippetRankInput.tokenSizeLimit());
7176
SearchHighlightContext searchHighlightContext = highlightBuilder.build(searchContext.getSearchExecutionContext());
7277
searchContext.highlight(searchHighlightContext);
7378
} catch (IOException e) {

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,32 +40,29 @@ public class RankFeatureShardRequest extends AbstractTransportRequest implements
4040

4141
private final int[] docIds;
4242

43-
private final RerankSnippetInput snippets;
44-
private final Integer tokenSizeLimit;
43+
private final CustomRankInput customRankInput;
4544

4645
public RankFeatureShardRequest(
4746
OriginalIndices originalIndices,
4847
ShardSearchContextId contextId,
4948
ShardSearchRequest shardSearchRequest,
5049
List<Integer> docIds
5150
) {
52-
this(originalIndices, contextId, shardSearchRequest, docIds, null, null);
51+
this(originalIndices, contextId, shardSearchRequest, docIds, null);
5352
}
5453

5554
public RankFeatureShardRequest(
5655
OriginalIndices originalIndices,
5756
ShardSearchContextId contextId,
5857
ShardSearchRequest shardSearchRequest,
5958
List<Integer> docIds,
60-
@Nullable RerankSnippetInput snippets,
61-
Integer tokenSizeLimit
59+
@Nullable CustomRankInput customRankInput
6260
) {
6361
this.originalIndices = originalIndices;
6462
this.shardSearchRequest = shardSearchRequest;
6563
this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray();
6664
this.contextId = contextId;
67-
this.snippets = snippets;
68-
this.tokenSizeLimit = tokenSizeLimit;
65+
this.customRankInput = customRankInput;
6966
}
7067

7168
public RankFeatureShardRequest(StreamInput in) throws IOException {
@@ -75,11 +72,14 @@ public RankFeatureShardRequest(StreamInput in) throws IOException {
7572
docIds = in.readIntArray();
7673
contextId = in.readOptionalWriteable(ShardSearchContextId::new);
7774
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
78-
snippets = in.readOptionalWriteable(RerankSnippetInput::new);
79-
this.tokenSizeLimit = in.readOptionalInt();
75+
String name = in.readOptionalString();
76+
if (name != null && name.equals(SnippetRankInput.NAME)) {
77+
customRankInput = new SnippetRankInput(in);
78+
} else {
79+
customRankInput = null;
80+
}
8081
} else {
81-
snippets = null;
82-
this.tokenSizeLimit = null;
82+
customRankInput = null;
8383
}
8484
}
8585

@@ -91,8 +91,11 @@ public void writeTo(StreamOutput out) throws IOException {
9191
out.writeIntArray(docIds);
9292
out.writeOptionalWriteable(contextId);
9393
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
94-
out.writeOptionalWriteable(snippets);
95-
out.writeOptionalInt(tokenSizeLimit);
94+
String name = customRankInput != null ? customRankInput.name() : null;
95+
out.writeOptionalString(name);
96+
if (customRankInput != null) {
97+
customRankInput.writeTo(out);
98+
}
9699
}
97100
}
98101

@@ -124,12 +127,8 @@ public ShardSearchContextId contextId() {
124127
return contextId;
125128
}
126129

127-
public RerankSnippetInput snippets() {
128-
return snippets;
129-
}
130-
131-
public Integer getTokenSizeLimit() {
132-
return tokenSizeLimit;
130+
public CustomRankInput customRankInput() {
131+
return customRankInput;
133132
}
134133

135134
@Override
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.rank.feature;
11+
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
15+
import java.io.IOException;
16+
17+
/**
18+
* Defines a custom rank input to rerank results based on snippets rather than full field contents.
19+
*/
20+
public class SnippetRankInput implements CustomRankInput {
21+
22+
static final String NAME = "snippets";
23+
24+
private final RerankSnippetInput snippets;
25+
private final String inferenceText;
26+
private final int tokenSizeLimit;
27+
28+
public SnippetRankInput(RerankSnippetInput snippets, String inferenceText, int tokenSizeLimit) {
29+
this.snippets = snippets;
30+
this.inferenceText = inferenceText;
31+
this.tokenSizeLimit = tokenSizeLimit;
32+
}
33+
34+
public SnippetRankInput(StreamInput in) throws IOException {
35+
this.snippets = new RerankSnippetInput(in);
36+
this.inferenceText = in.readString();
37+
this.tokenSizeLimit = in.readVInt();
38+
}
39+
40+
public RerankSnippetInput snippets() {
41+
return snippets;
42+
}
43+
44+
public String inferenceText() {
45+
return inferenceText;
46+
}
47+
48+
public Integer tokenSizeLimit() {
49+
return tokenSizeLimit;
50+
}
51+
52+
@Override
53+
public void writeTo(StreamOutput out) throws IOException {
54+
snippets.writeTo(out);
55+
out.writeString(inferenceText);
56+
out.writeVInt(tokenSizeLimit);
57+
}
58+
59+
@Override
60+
public String name() {
61+
return NAME;
62+
}
63+
}

server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.logging.log4j.LogManager;
1313
import org.apache.logging.log4j.Logger;
1414
import org.elasticsearch.common.document.DocumentField;
15+
import org.elasticsearch.common.logging.HeaderWarning;
1516
import org.elasticsearch.search.SearchHit;
1617
import org.elasticsearch.search.SearchHits;
1718
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
@@ -57,11 +58,18 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
5758
rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
5859
} else {
5960
Map<String, HighlightField> highlightFields = hit.getHighlightFields();
60-
if (highlightFields != null) {
61-
if (highlightFields.containsKey(field)) {
62-
List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
63-
rankFeatureDocs[i].featureData(snippets);
64-
}
61+
if (highlightFields != null && highlightFields.containsKey(field)) {
62+
List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
63+
rankFeatureDocs[i].featureData(snippets);
64+
} else if (docField != null) {
65+
// If we did not get highlighting results, backfill with the doc field value
66+
// but pass in a warning because we are not reranking on snippets only
67+
rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
68+
HeaderWarning.addWarning(
69+
"Reranking on snippets requested, but no snippets were found for field ["
70+
+ field
71+
+ "]. Using field value instead."
72+
);
6573
}
6674
}
6775
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
2525
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
2626
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
27+
import org.elasticsearch.search.rank.feature.SnippetRankInput;
2728
import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
2829
import org.elasticsearch.xcontent.XContentBuilder;
2930

@@ -37,6 +38,8 @@
3738
import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD;
3839
import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD;
3940
import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD;
41+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.DEFAULT_RERANK_ID;
42+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;
4043

4144
/**
4245
* A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call.
@@ -45,6 +48,17 @@ public class TextSimilarityRankBuilder extends RankBuilder {
4548

4649
public static final String NAME = "text_similarity_reranker";
4750

51+
/**
52+
* The default token size limit of the Elastic reranker.
53+
*/
54+
private static final int RERANK_TOKEN_SIZE_LIMIT = 512;
55+
56+
/**
57+
* A safe default token size limit for other reranker models.
58+
* Reranker models with smaller token limits will be truncated.
59+
*/
60+
private static final int DEFAULT_TOKEN_SIZE_LIMIT = 4096;
61+
4862
public static final LicensedFeature.Momentary TEXT_SIMILARITY_RERANKER_FEATURE = LicensedFeature.momentary(
4963
null,
5064
"text-similarity-reranker",
@@ -198,10 +212,22 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
198212
inferenceText,
199213
minScore,
200214
failuresAllowed,
201-
snippets
215+
snippets != null ? new SnippetRankInput(snippets, inferenceText, tokenSizeLimit()) : null
202216
);
203217
}
204218

219+
/**
220+
* @return The token size limit to apply to this rerank context.
221+
* This is not yet available so we are hardcoding it for now.
222+
*/
223+
public Integer tokenSizeLimit() {
224+
if (inferenceId.equals(DEFAULT_RERANK_ID) || inferenceId.equals(RERANKER_ID)) {
225+
return RERANK_TOKEN_SIZE_LIMIT;
226+
}
227+
228+
return DEFAULT_TOKEN_SIZE_LIMIT;
229+
}
230+
205231
public String field() {
206232
return field;
207233
}

0 commit comments

Comments
 (0)