Skip to content

Commit cf034c0

Browse files
authored
Add a new random rerank retriever (#111851)
* Add a new random rerank retriever, that reranks results in random order without requiring inference * Update docs/changelog/111851.yaml * PR feedback - remove null checks for field as it can never be null * Update docs * Revert "Update docs" This reverts commit 3d61676. * Remove minScore * Random seed * Delete docs/changelog/111851.yaml * PR feedback * Add optional seed to request, YAML test * PR feedback
1 parent 8b0a1aa commit cf034c0

File tree

9 files changed

+623
-2
lines changed

9 files changed

+623
-2
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ static TransportVersion def(int id) {
191191
public static final TransportVersion INGEST_PIPELINE_EXCEPTION_ADDED = def(8_721_00_0);
192192
public static final TransportVersion ZDT_NANOS_SUPPORT = def(8_722_00_0);
193193
public static final TransportVersion REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES = def(8_723_00_0);
194+
public static final TransportVersion RANDOM_RERANKER_RETRIEVER = def(8_724_00_0);
194195

195196
/*
196197
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.features.FeatureSpecification;
1111
import org.elasticsearch.features.NodeFeature;
12+
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
1213
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
1314

1415
import java.util.Set;
@@ -20,7 +21,10 @@ public class InferenceFeatures implements FeatureSpecification {
2021

2122
@Override
2223
public Set<NodeFeature> getFeatures() {
23-
return Set.of(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED);
24+
return Set.of(
25+
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
26+
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED
27+
);
2428
}
2529

2630
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
6464
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
6565
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
66+
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
67+
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
6668
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
6769
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
6870
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -243,6 +245,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
243245
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
244246
var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables());
245247
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new));
248+
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new));
246249
return entries;
247250
}
248251

@@ -336,7 +339,8 @@ public List<QuerySpec<?>> getQueries() {
336339
@Override
337340
public List<RetrieverSpec<?>> getRetrievers() {
338341
return List.of(
339-
new RetrieverSpec<>(new ParseField(TextSimilarityRankBuilder.NAME), TextSimilarityRankRetrieverBuilder::fromXContent)
342+
new RetrieverSpec<>(new ParseField(TextSimilarityRankBuilder.NAME), TextSimilarityRankRetrieverBuilder::fromXContent),
343+
new RetrieverSpec<>(new ParseField(RandomRankBuilder.NAME), RandomRankRetrieverBuilder::fromXContent)
340344
);
341345
}
342346
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.rank.random;
9+
10+
import org.apache.lucene.search.Explanation;
11+
import org.apache.lucene.search.Query;
12+
import org.elasticsearch.TransportVersion;
13+
import org.elasticsearch.TransportVersions;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.common.io.stream.StreamInput;
16+
import org.elasticsearch.common.io.stream.StreamOutput;
17+
import org.elasticsearch.search.rank.RankBuilder;
18+
import org.elasticsearch.search.rank.RankDoc;
19+
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
20+
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
21+
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
22+
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
23+
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
24+
import org.elasticsearch.search.rank.rerank.RerankingQueryPhaseRankCoordinatorContext;
25+
import org.elasticsearch.search.rank.rerank.RerankingQueryPhaseRankShardContext;
26+
import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
27+
import org.elasticsearch.xcontent.ConstructingObjectParser;
28+
import org.elasticsearch.xcontent.XContentBuilder;
29+
30+
import java.io.IOException;
31+
import java.util.List;
32+
import java.util.Objects;
33+
34+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
35+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
36+
import static org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder.FIELD_FIELD;
37+
import static org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder.SEED_FIELD;
38+
39+
/**
40+
* A {@code RankBuilder} that performs reranking with random scores, used for testing.
41+
*/
42+
public class RandomRankBuilder extends RankBuilder {
43+
44+
public static final String NAME = "random_reranker";
45+
46+
static final ConstructingObjectParser<RandomRankBuilder, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
47+
Integer rankWindowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (Integer) args[0];
48+
String field = (String) args[1];
49+
Integer seed = (Integer) args[2];
50+
51+
return new RandomRankBuilder(rankWindowSize, field, seed);
52+
});
53+
54+
static {
55+
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
56+
PARSER.declareString(constructorArg(), FIELD_FIELD);
57+
PARSER.declareInt(optionalConstructorArg(), SEED_FIELD);
58+
}
59+
60+
private final String field;
61+
private final Integer seed;
62+
63+
public RandomRankBuilder(int rankWindowSize, String field, Integer seed) {
64+
super(rankWindowSize);
65+
66+
if (field == null || field.isEmpty()) {
67+
throw new IllegalArgumentException("field is required");
68+
}
69+
70+
this.field = field;
71+
this.seed = seed;
72+
}
73+
74+
public RandomRankBuilder(StreamInput in) throws IOException {
75+
super(in);
76+
// rankWindowSize deserialization is handled by the parent class RankBuilder
77+
this.field = in.readString();
78+
this.seed = in.readOptionalInt();
79+
}
80+
81+
@Override
82+
public String getWriteableName() {
83+
return NAME;
84+
}
85+
86+
@Override
87+
public TransportVersion getMinimalSupportedVersion() {
88+
return TransportVersions.RANDOM_RERANKER_RETRIEVER;
89+
}
90+
91+
@Override
92+
public void doWriteTo(StreamOutput out) throws IOException {
93+
// rankWindowSize serialization is handled by the parent class RankBuilder
94+
out.writeString(field);
95+
out.writeOptionalInt(seed);
96+
}
97+
98+
@Override
99+
public void doXContent(XContentBuilder builder, Params params) throws IOException {
100+
// rankWindowSize serialization is handled by the parent class RankBuilder
101+
builder.field(FIELD_FIELD.getPreferredName(), field);
102+
if (seed != null) {
103+
builder.field(SEED_FIELD.getPreferredName(), seed);
104+
}
105+
}
106+
107+
@Override
108+
public boolean isCompoundBuilder() {
109+
return false;
110+
}
111+
112+
@Override
113+
public Explanation explainHit(Explanation baseExplanation, RankDoc scoreDoc, List<String> queryNames) {
114+
if (scoreDoc == null) {
115+
return baseExplanation;
116+
}
117+
if (false == baseExplanation.isMatch()) {
118+
return baseExplanation;
119+
}
120+
121+
assert scoreDoc instanceof RankFeatureDoc : "ScoreDoc is not an instance of RankFeatureDoc";
122+
RankFeatureDoc rankFeatureDoc = (RankFeatureDoc) scoreDoc;
123+
124+
return Explanation.match(
125+
rankFeatureDoc.score,
126+
"rank after reranking: [" + rankFeatureDoc.rank + "] using seed [" + seed + "] with score: [" + rankFeatureDoc.score + "]",
127+
baseExplanation
128+
);
129+
}
130+
131+
@Override
132+
public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
133+
return new RerankingQueryPhaseRankShardContext(queries, rankWindowSize());
134+
}
135+
136+
@Override
137+
public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
138+
return new RerankingQueryPhaseRankCoordinatorContext(rankWindowSize());
139+
}
140+
141+
@Override
142+
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
143+
return new RerankingRankFeaturePhaseRankShardContext(field);
144+
}
145+
146+
@Override
147+
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
148+
return new RandomRankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), seed);
149+
}
150+
151+
public String field() {
152+
return field;
153+
}
154+
155+
@Override
156+
protected boolean doEquals(RankBuilder other) {
157+
RandomRankBuilder that = (RandomRankBuilder) other;
158+
return Objects.equals(field, that.field) && Objects.equals(seed, that.seed);
159+
}
160+
161+
@Override
162+
protected int doHashCode() {
163+
return Objects.hash(field, seed);
164+
}
165+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.rank.random;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
12+
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
13+
14+
import java.util.Arrays;
15+
import java.util.Comparator;
16+
import java.util.Random;
17+
18+
/**
19+
* A {@code RankFeaturePhaseRankCoordinatorContext} that performs a rerank inference call to determine relevance scores for documents within
20+
* the provided rank window.
21+
*/
22+
public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePhaseRankCoordinatorContext {
23+
24+
private final Integer seed;
25+
26+
public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) {
27+
super(size, from, rankWindowSize);
28+
this.seed = seed;
29+
}
30+
31+
@Override
32+
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
33+
// Generate random scores seeded by doc
34+
float[] scores = new float[featureDocs.length];
35+
for (int i = 0; i < featureDocs.length; i++) {
36+
RankFeatureDoc featureDoc = featureDocs[i];
37+
int doc = featureDoc.doc;
38+
long docSeed = seed != null ? seed + doc : doc;
39+
scores[i] = new Random(docSeed).nextFloat();
40+
}
41+
scoreListener.onResponse(scores);
42+
}
43+
44+
/**
45+
* Sorts documents by score descending.
46+
* @param originalDocs documents to process
47+
*/
48+
@Override
49+
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
50+
return Arrays.stream(originalDocs)
51+
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
52+
.toArray(RankFeatureDoc[]::new);
53+
}
54+
55+
}

0 commit comments

Comments
 (0)