Skip to content

Commit f1cd6d7

Browse files
authored
Optionally allow text similarity reranking to fail (#121784) (#128250)
Backport #121784 to 8.19
1 parent 1ba61eb commit f1cd6d7

File tree

24 files changed

+481
-351
lines changed

24 files changed

+481
-351
lines changed

docs/changelog/121784.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 121784
2+
summary: Optionally allow text similarity reranking to fail
3+
area: Search
4+
type: enhancement
5+
issues: []

server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
205205

206206
@Override
207207
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
208-
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) {
208+
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), false) {
209209
@Override
210210
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
211211
float[] scores = new float[featureDocs.length];
@@ -346,7 +346,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
346346
@Override
347347
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
348348
if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT)
349-
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) {
349+
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), false) {
350350
@Override
351351
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
352352
throw new UnsupportedOperationException("rfc - simulated failure");

server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ public static class TestRerankingRankFeaturePhaseRankCoordinatorContext extends
249249
String inferenceText,
250250
float minScore
251251
) {
252-
super(size, from, windowSize);
252+
super(size, from, windowSize, false);
253253
this.client = client;
254254
this.inferenceId = inferenceId;
255255
this.inferenceText = inferenceText;

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ static TransportVersion def(int id) {
223223
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
224224
public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM_BACKPORT_8_19 = def(8_841_0_33);
225225
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
226+
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
226227

227228
/*
228229
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
import org.elasticsearch.search.builder.SearchSourceBuilder;
2121
import org.elasticsearch.search.dfs.AggregatedDfs;
2222
import org.elasticsearch.search.internal.ShardSearchContextId;
23+
import org.elasticsearch.search.rank.RankDoc;
2324
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
2425
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
2526
import org.elasticsearch.search.rank.feature.RankFeatureResult;
2627
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
2728
import org.elasticsearch.transport.Transport;
2829

30+
import java.util.Arrays;
2931
import java.util.List;
3032

3133
/**
@@ -186,7 +188,7 @@ private void onPhaseDone(
186188
new ActionListener<>() {
187189
@Override
188190
public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
189-
RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores);
191+
RankDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores, true);
190192
SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults(
191193
reducedQueryPhase,
192194
topResults
@@ -196,12 +198,36 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
196198

197199
@Override
198200
public void onFailure(Exception e) {
199-
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
201+
if (rankFeaturePhaseRankCoordinatorContext.failuresAllowed()) {
202+
// TODO: handle the exception somewhere
203+
// don't want to log the entire stack trace, it's not helpful here
204+
logger.warn("Exception computing updated ranks, continuing with existing ranks: {}", e.toString());
205+
// use the existing score docs as-is
206+
// downstream things expect every doc to have a score, so we need to infer a score here
207+
// if the doc doesn't otherwise have one. We can use the rank to infer a possible score instead (1/rank).
208+
ScoreDoc[] inputDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
209+
RankFeatureDoc[] rankDocs = new RankFeatureDoc[inputDocs.length];
210+
for (int i = 0; i < inputDocs.length; i++) {
211+
ScoreDoc doc = inputDocs[i];
212+
rankDocs[i] = new RankFeatureDoc(doc.doc, Float.isNaN(doc.score) ? 1f / (i + 1) : doc.score, doc.shardIndex);
213+
}
214+
RankDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(rankDocs, false);
215+
SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults(
216+
reducedQueryPhase,
217+
topResults
218+
);
219+
moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase);
220+
} else {
221+
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
222+
}
200223
}
201224
}
202225
);
203226
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(
204-
rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(),
227+
rankPhaseResults.getSuccessfulResults()
228+
.flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs))
229+
.filter(rfd -> rfd.featureData != null)
230+
.toArray(RankFeatureDoc[]::new),
205231
rankResultListener
206232
);
207233
}
@@ -210,7 +236,6 @@ private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults(
210236
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
211237
ScoreDoc[] scoreDocs
212238
) {
213-
214239
return new SearchPhaseController.ReducedQueryPhase(
215240
reducedQueryPhase.totalHits(),
216241
reducedQueryPhase.fetchHits(),

server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ public final boolean equals(Object obj) {
117117
if (obj == null || getClass() != obj.getClass()) {
118118
return false;
119119
}
120-
@SuppressWarnings("unchecked")
121120
RankBuilder other = (RankBuilder) obj;
122-
return Objects.equals(rankWindowSize, other.rankWindowSize()) && doEquals(other);
121+
return rankWindowSize == other.rankWindowSize && doEquals(other);
123122
}
124123

125124
protected abstract boolean doEquals(RankBuilder other);

server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,9 @@
1212
import org.apache.lucene.search.ScoreDoc;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
15-
import org.elasticsearch.search.rank.feature.RankFeatureResult;
16-
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
1715

18-
import java.util.ArrayList;
1916
import java.util.Arrays;
2017
import java.util.Comparator;
21-
import java.util.List;
2218

2319
import static org.elasticsearch.search.SearchService.DEFAULT_FROM;
2420
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
@@ -33,11 +29,17 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
3329
protected final int size;
3430
protected final int from;
3531
protected final int rankWindowSize;
32+
protected final boolean failuresAllowed;
3633

37-
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
34+
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed) {
3835
this.size = size < 0 ? DEFAULT_SIZE : size;
3936
this.from = from < 0 ? DEFAULT_FROM : from;
4037
this.rankWindowSize = rankWindowSize;
38+
this.failuresAllowed = failuresAllowed;
39+
}
40+
41+
public boolean failuresAllowed() {
42+
return failuresAllowed;
4143
}
4244

4345
/**
@@ -48,12 +50,13 @@ public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindow
4850

4951
/**
5052
* Preprocesses the provided documents: sorts them by score descending.
51-
* @param originalDocs documents to process
53+
*
54+
* @param originalDocs documents to process
55+
* @param rerankedScores {@code true} if the document scores have been reranked
5256
*/
53-
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
54-
return Arrays.stream(originalDocs)
55-
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
56-
.toArray(RankFeatureDoc[]::new);
57+
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rerankedScores) {
58+
Arrays.sort(originalDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
59+
return originalDocs;
5760
}
5861

5962
/**
@@ -64,16 +67,10 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
6467
* Once all the scores have been computed, we sort the results, perform any pagination needed, and then notify the `rankListener`
6568
* with the final array of {@link ScoreDoc} results.
6669
*
67-
* @param rankSearchResults a list of rank feature results from each shard
70+
* @param featureDocs an array of the rank feature docs collected from all shards
6871
* @param rankListener a rankListener to handle the global ranking result
6972
*/
70-
public void computeRankScoresForGlobalResults(
71-
List<RankFeatureResult> rankSearchResults,
72-
ActionListener<RankFeatureDoc[]> rankListener
73-
) {
74-
// extract feature data from each shard rank-feature phase result
75-
RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults);
76-
73+
public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener<RankFeatureDoc[]> rankListener) {
7774
// generate the final `topResults` results, and pass them to fetch phase through the `rankListener`
7875
computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> {
7976
for (int i = 0; i < featureDocs.length; i++) {
@@ -86,28 +83,17 @@ public void computeRankScoresForGlobalResults(
8683
/**
8784
* Ranks the provided {@link RankFeatureDoc} array and paginates the results based on the `from` and `size` parameters. Filters out
8885
* documents that have a relevance score less than min_score.
86+
*
8987
* @param rankFeatureDocs documents to process
88+
* @param rerankedScores {@code true} if the document scores have been reranked
9089
*/
91-
public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
92-
RankFeatureDoc[] sortedDocs = preprocess(rankFeatureDocs);
90+
public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs, boolean rerankedScores) {
91+
RankFeatureDoc[] sortedDocs = preprocess(rankFeatureDocs, rerankedScores);
9392
RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, sortedDocs.length - from))];
9493
for (int rank = 0; rank < topResults.length; ++rank) {
9594
topResults[rank] = sortedDocs[from + rank];
9695
topResults[rank].rank = from + rank + 1;
9796
}
9897
return topResults;
9998
}
100-
101-
private RankFeatureDoc[] extractFeatureDocs(List<RankFeatureResult> rankSearchResults) {
102-
List<RankFeatureDoc> docFeatures = new ArrayList<>();
103-
for (RankFeatureResult rankFeatureResult : rankSearchResults) {
104-
RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
105-
for (RankFeatureDoc rankFeatureDoc : shardResult.rankFeatureDocs) {
106-
if (rankFeatureDoc.featureData != null) {
107-
docFeatures.add(rankFeatureDoc);
108-
}
109-
}
110-
}
111-
return docFeatures.toArray(new RankFeatureDoc[0]);
112-
}
11399
}

server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ public void sendExecuteRankFeature(
775775
}
776776

777777
private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
778-
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) {
778+
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize, false) {
779779

780780
@Override
781781
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
@@ -785,20 +785,12 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
785785
}
786786

787787
@Override
788-
public void computeRankScoresForGlobalResults(
789-
List<RankFeatureResult> rankSearchResults,
790-
ActionListener<RankFeatureDoc[]> rankListener
791-
) {
792-
List<RankFeatureDoc> features = new ArrayList<>();
793-
for (RankFeatureResult rankFeatureResult : rankSearchResults) {
794-
RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
795-
features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList());
796-
}
797-
rankListener.onResponse(features.toArray(new RankFeatureDoc[0]));
788+
public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener<RankFeatureDoc[]> rankListener) {
789+
rankListener.onResponse(featureDocs);
798790
}
799791

800792
@Override
801-
public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
793+
public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs, boolean rerankedScores) {
802794
Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
803795
RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))];
804796
// perform pagination

server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
693693
int from,
694694
Client client
695695
) {
696-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
696+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
697697
@Override
698698
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
699699
float[] scores = new float[featureDocs.length];
@@ -837,7 +837,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
837837
int from,
838838
Client client
839839
) {
840-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
840+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
841841
@Override
842842
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
843843
throw new IllegalStateException("should have failed earlier");
@@ -953,7 +953,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
953953
int from,
954954
Client client
955955
) {
956-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
956+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
957957
@Override
958958
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
959959
float[] scores = new float[featureDocs.length];
@@ -1081,7 +1081,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
10811081
int from,
10821082
Client client
10831083
) {
1084-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
1084+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
10851085
@Override
10861086
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
10871087
float[] scores = new float[featureDocs.length];

server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
171171
// no work to be done on the coordinator node for the rank feature phase
172172
@Override
173173
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
174-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
174+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
175175
@Override
176176
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
177177
throw new AssertionError("not expected");

0 commit comments

Comments
 (0)