Skip to content

Commit ad91d47

Browse files
committed
Add a lenient option to text similarity reranking
1 parent c8053d4 commit ad91d47

File tree

19 files changed

+103
-79
lines changed

19 files changed

+103
-79
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ static TransportVersion def(int id) {
177177
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
178178
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES = def(9_002_0_00);
179179
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED = def(9_003_0_00);
180-
181180
public static final TransportVersion REMOVE_DESIRED_NODE_VERSION = def(9_004_0_00);
181+
public static final TransportVersion LENIENT_RERANKERS = def(9_005_0_00);
182182

183183
/*
184184
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
2727
import org.elasticsearch.transport.Transport;
2828

29+
import java.util.Arrays;
2930
import java.util.List;
3031

3132
/**
@@ -181,6 +182,11 @@ private void onPhaseDone(
181182
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
182183
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
183184
) {
185+
RankFeatureDoc[] docs = rankPhaseResults.getSuccessfulResults()
186+
.flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs))
187+
.filter(rfd -> rfd.featureData != null)
188+
.toArray(RankFeatureDoc[]::new);
189+
184190
ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(
185191
context::execute,
186192
new ActionListener<>() {
@@ -196,21 +202,26 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
196202

197203
@Override
198204
public void onFailure(Exception e) {
199-
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
205+
if (rankFeaturePhaseRankCoordinatorContext.isLenient()) {
206+
// TODO: handle the exception somewhere
207+
logger.warn("Exception computing updated ranks. Continuing with existing ranks.", e);
208+
// use the existing docs as-is
209+
// AbstractThreadedActionListener forks onFailure to the same executor as onResponse,
210+
// so we can just call this directly
211+
onResponse(docs);
212+
} else {
213+
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
214+
}
200215
}
201216
}
202217
);
203-
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(
204-
rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(),
205-
rankResultListener
206-
);
218+
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(docs, rankResultListener);
207219
}
208220

209221
private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults(
210222
SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
211223
ScoreDoc[] scoreDocs
212224
) {
213-
214225
return new SearchPhaseController.ReducedQueryPhase(
215226
reducedQueryPhase.totalHits(),
216227
reducedQueryPhase.fetchHits(),

server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.apache.lucene.search.Explanation;
1313
import org.apache.lucene.search.Query;
14+
import org.elasticsearch.TransportVersions;
1415
import org.elasticsearch.client.internal.Client;
1516
import org.elasticsearch.common.Strings;
1617
import org.elasticsearch.common.io.stream.StreamInput;
@@ -42,21 +43,32 @@
4243
public abstract class RankBuilder implements VersionedNamedWriteable, ToXContentObject {
4344

4445
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
46+
public static final ParseField LENIENT_FIELD = new ParseField("lenient");
4547

4648
public static final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE;
4749

4850
private final int rankWindowSize;
51+
private final boolean lenient;
4952

50-
public RankBuilder(int rankWindowSize) {
53+
public RankBuilder(int rankWindowSize, boolean lenient) {
5154
this.rankWindowSize = rankWindowSize;
55+
this.lenient = lenient;
5256
}
5357

5458
public RankBuilder(StreamInput in) throws IOException {
5559
rankWindowSize = in.readVInt();
60+
if (in.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) {
61+
lenient = in.readBoolean();
62+
} else {
63+
lenient = false;
64+
}
5665
}
5766

5867
public final void writeTo(StreamOutput out) throws IOException {
5968
out.writeVInt(rankWindowSize);
69+
if (out.getTransportVersion().onOrAfter(TransportVersions.LENIENT_RERANKERS)) {
70+
out.writeBoolean(lenient);
71+
}
6072
doWriteTo(out);
6173
}
6274

@@ -67,6 +79,9 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)
6779
builder.startObject();
6880
builder.startObject(getWriteableName());
6981
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
82+
if (lenient) {
83+
builder.field(LENIENT_FIELD.getPreferredName(), lenient);
84+
}
7085
doXContent(builder, params);
7186
builder.endObject();
7287
builder.endObject();
@@ -79,6 +94,10 @@ public int rankWindowSize() {
7994
return rankWindowSize;
8095
}
8196

97+
public boolean isLenient() {
98+
return lenient;
99+
}
100+
82101
/**
83102
* Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires
84103
* two or more queries to be executed in order to generate the final result.

server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,17 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
3333
protected final int size;
3434
protected final int from;
3535
protected final int rankWindowSize;
36+
protected final boolean lenient;
3637

37-
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
38+
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean lenient) {
3839
this.size = size < 0 ? DEFAULT_SIZE : size;
3940
this.from = from < 0 ? DEFAULT_FROM : from;
4041
this.rankWindowSize = rankWindowSize;
42+
this.lenient = lenient;
43+
}
44+
45+
public boolean isLenient() {
46+
return lenient;
4147
}
4248

4349
/**
@@ -51,9 +57,9 @@ public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindow
5157
* @param originalDocs documents to process
5258
*/
5359
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
54-
return Arrays.stream(originalDocs)
55-
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
56-
.toArray(RankFeatureDoc[]::new);
60+
RankFeatureDoc[] sorted = originalDocs.clone();
61+
Arrays.sort(sorted, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
62+
return sorted;
5763
}
5864

5965
/**
@@ -64,16 +70,10 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
6470
* Once all the scores have been computed, we sort the results, perform any pagination needed, and then call the `onFinish` consumer
6571
* with the final array of {@link ScoreDoc} results.
6672
*
67-
* @param rankSearchResults a list of rank feature results from each shard
73+
* @param featureDocs an array of rank feature docs collected from the shard results
6874
* @param rankListener a rankListener to handle the global ranking result
6975
*/
70-
public void computeRankScoresForGlobalResults(
71-
List<RankFeatureResult> rankSearchResults,
72-
ActionListener<RankFeatureDoc[]> rankListener
73-
) {
74-
// extract feature data from each shard rank-feature phase result
75-
RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults);
76-
76+
public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener<RankFeatureDoc[]> rankListener) {
7777
// generate the final `topResults` results, and pass them to fetch phase through the `rankListener`
7878
computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> {
7979
for (int i = 0; i < featureDocs.length; i++) {

server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ public void sendExecuteRankFeature(
775775
}
776776

777777
private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
778-
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) {
778+
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize, false) {
779779

780780
@Override
781781
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
@@ -785,16 +785,8 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
785785
}
786786

787787
@Override
788-
public void computeRankScoresForGlobalResults(
789-
List<RankFeatureResult> rankSearchResults,
790-
ActionListener<RankFeatureDoc[]> rankListener
791-
) {
792-
List<RankFeatureDoc> features = new ArrayList<>();
793-
for (RankFeatureResult rankFeatureResult : rankSearchResults) {
794-
RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
795-
features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList());
796-
}
797-
rankListener.onResponse(features.toArray(new RankFeatureDoc[0]));
788+
public void computeRankScoresForGlobalResults(RankFeatureDoc[] featureDocs, ActionListener<RankFeatureDoc[]> rankListener) {
789+
rankListener.onResponse(featureDocs);
798790
}
799791

800792
@Override
@@ -875,7 +867,7 @@ private RankBuilder rankBuilder(
875867
RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext,
876868
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
877869
) {
878-
return new RankBuilder(rankWindowSize) {
870+
return new RankBuilder(rankWindowSize, false) {
879871
@Override
880872
protected void doWriteTo(StreamOutput out) throws IOException {
881873
// no-op

server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
687687
int from,
688688
Client client
689689
) {
690-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
690+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
691691
@Override
692692
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
693693
float[] scores = new float[featureDocs.length];
@@ -831,7 +831,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
831831
int from,
832832
Client client
833833
) {
834-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
834+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
835835
@Override
836836
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
837837
throw new IllegalStateException("should have failed earlier");
@@ -947,7 +947,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
947947
int from,
948948
Client client
949949
) {
950-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
950+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
951951
@Override
952952
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
953953
float[] scores = new float[featureDocs.length];
@@ -1075,7 +1075,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
10751075
int from,
10761076
Client client
10771077
) {
1078-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
1078+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
10791079
@Override
10801080
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
10811081
float[] scores = new float[featureDocs.length];

server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ public boolean isCancelled() {
117117
}
118118

119119
private RankBuilder getRankBuilder(final String field) {
120-
return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) {
120+
return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE, false) {
121121
@Override
122122
protected void doWriteTo(StreamOutput out) throws IOException {
123123
// no-op
@@ -171,7 +171,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
171171
// no work to be done on the coordinator node for the rank feature phase
172172
@Override
173173
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
174-
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
174+
return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE, false) {
175175
@Override
176176
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
177177
throw new AssertionError("not expected");

test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public static TestRankBuilder randomRankBuilder() {
5252
}
5353

5454
public TestRankBuilder(int windowSize) {
55-
super(windowSize);
55+
super(windowSize, false);
5656
}
5757

5858
public TestRankBuilder(StreamInput in) throws IOException {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public class RandomRankBuilder extends RankBuilder {
6161
private final Integer seed;
6262

6363
public RandomRankBuilder(int rankWindowSize, String field, Integer seed) {
64-
super(rankWindowSize);
64+
super(rankWindowSize, false);
6565

6666
if (field == null || field.isEmpty()) {
6767
throw new IllegalArgumentException("field is required");

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
1212
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
1313

14-
import java.util.Arrays;
15-
import java.util.Comparator;
1614
import java.util.Random;
1715

1816
/**
@@ -24,7 +22,7 @@ public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePha
2422
private final Integer seed;
2523

2624
public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) {
27-
super(size, from, rankWindowSize);
25+
super(size, from, rankWindowSize, false);
2826
this.seed = seed;
2927
}
3028

@@ -40,16 +38,4 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
4038
}
4139
scoreListener.onResponse(scores);
4240
}
43-
44-
/**
45-
* Sorts documents by score descending.
46-
* @param originalDocs documents to process
47-
*/
48-
@Override
49-
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
50-
return Arrays.stream(originalDocs)
51-
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
52-
.toArray(RankFeatureDoc[]::new);
53-
}
54-
5541
}

0 commit comments

Comments
 (0)