Skip to content

Commit 7794bef

Browse files
authored
Transforming rank rrf to the corresponding retriever (#115026)
1 parent 3ae7921 commit 7794bef

File tree

10 files changed

+181
-104
lines changed

10 files changed

+181
-104
lines changed

server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@ private SearchCapabilities() {}
2424
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
2525
/** Support Byte and Float with Bit dot product. */
2626
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";
27+
/** Support transforming rank rrf queries to the corresponding rrf retriever. */
28+
private static final String TRANSFORM_RANK_RRF_TO_RETRIEVER = "transform_rank_rrf_to_retriever";
2729

2830
public static final Set<String> CAPABILITIES = Set.of(
2931
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
3032
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
31-
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY
33+
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY,
34+
TRANSFORM_RANK_RRF_TO_RETRIEVER
3235
);
3336
}

server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,18 @@ private SearchSourceBuilder parseXContent(
16381638
}
16391639

16401640
knnSearch = knnBuilders.stream().map(knnBuilder -> knnBuilder.build(size())).collect(Collectors.toList());
1641+
if (rankBuilder != null) {
1642+
if (retrieverBuilder != null) {
1643+
throw new IllegalArgumentException("Cannot specify both [rank] and [retriever].");
1644+
}
1645+
RetrieverBuilder transformedRetriever = rankBuilder.toRetriever(this, clusterSupportsFeature);
1646+
if (transformedRetriever != null) {
1647+
this.retriever(transformedRetriever);
1648+
rankBuilder = null;
1649+
subSearchSourceBuilders.clear();
1650+
knnSearch.clear();
1651+
}
1652+
}
16411653
searchUsageConsumer.accept(searchUsage);
16421654
return this;
16431655
}

server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,24 @@
1616
import org.elasticsearch.common.io.stream.StreamInput;
1717
import org.elasticsearch.common.io.stream.StreamOutput;
1818
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
19+
import org.elasticsearch.core.Nullable;
20+
import org.elasticsearch.core.UpdateForV10;
21+
import org.elasticsearch.features.NodeFeature;
1922
import org.elasticsearch.search.SearchService;
23+
import org.elasticsearch.search.builder.SearchSourceBuilder;
2024
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
2125
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
2226
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
2327
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
28+
import org.elasticsearch.search.retriever.RetrieverBuilder;
2429
import org.elasticsearch.xcontent.ParseField;
2530
import org.elasticsearch.xcontent.ToXContentObject;
2631
import org.elasticsearch.xcontent.XContentBuilder;
2732

2833
import java.io.IOException;
2934
import java.util.List;
3035
import java.util.Objects;
36+
import java.util.function.Predicate;
3137

3238
/**
3339
* {@code RankBuilder} is used as a base class to manage input, parsing, and subsequent generation of appropriate contexts
@@ -109,6 +115,16 @@ public int rankWindowSize() {
109115
*/
110116
public abstract RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client);
111117

118+
/**
119+
* Transforms the specific rank builder (as parsed through SearchSourceBuilder) to the corresponding retriever.
120+
* This is used to ensure smooth deprecation of {@code rank} and {@code sub_searches} and move towards the retriever framework.
* <p>
* This base implementation performs no transformation and always returns {@code null}; subclasses that have an
* equivalent retriever are expected to override it.
*
* @param searchSourceBuilder the search source that this rank builder was parsed as part of
* @param clusterSupportsFeature predicate for checking whether a given {@link NodeFeature} is supported
* @return the equivalent retriever, or {@code null} when no transformation is available
121+
*/
122+
@UpdateForV10(owner = UpdateForV10.Owner.SEARCH_RELEVANCE) // remove for 10.0 once we remove support for the rank parameter in SearchAPI
123+
@Nullable
124+
public RetrieverBuilder toRetriever(SearchSourceBuilder searchSourceBuilder, Predicate<NodeFeature> clusterSupportsFeature) {
125+
return null;
126+
}
127+
112128
@Override
113129
public final boolean equals(Object obj) {
114130
if (this == obj) {

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ public int k() {
286286
return k;
287287
}
288288

289+
/**
* Returns the {@code numCands} value held by this builder.
*/
public int getNumCands() {
290+
return numCands;
291+
}
292+
289293
public QueryVectorBuilder getQueryVectorBuilder() {
290294
return queryVectorBuilder;
291295
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,32 @@
1414
import org.elasticsearch.client.internal.Client;
1515
import org.elasticsearch.common.io.stream.StreamInput;
1616
import org.elasticsearch.common.io.stream.StreamOutput;
17+
import org.elasticsearch.features.NodeFeature;
1718
import org.elasticsearch.license.LicenseUtils;
19+
import org.elasticsearch.search.builder.SearchSourceBuilder;
1820
import org.elasticsearch.search.rank.RankBuilder;
1921
import org.elasticsearch.search.rank.RankDoc;
2022
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
2123
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
2224
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
2325
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
26+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
27+
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
28+
import org.elasticsearch.search.retriever.RetrieverBuilder;
29+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
30+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
2431
import org.elasticsearch.xcontent.ConstructingObjectParser;
2532
import org.elasticsearch.xcontent.ParseField;
2633
import org.elasticsearch.xcontent.XContentBuilder;
2734
import org.elasticsearch.xcontent.XContentParser;
2835
import org.elasticsearch.xpack.core.XPackPlugin;
2936

3037
import java.io.IOException;
38+
import java.util.ArrayList;
3139
import java.util.Arrays;
3240
import java.util.List;
3341
import java.util.Objects;
42+
import java.util.function.Predicate;
3443

3544
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3645

@@ -183,6 +192,37 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
183192
return null;
184193
}
185194

195+
/**
* Converts this RRF rank definition, together with the sub searches and kNN searches of the given source,
* into a single {@link RRFRetrieverBuilder}.
* <p>
* Every sub search query is wrapped in a {@link StandardRetrieverBuilder} and every kNN search is converted to a
* {@link KnnRetrieverBuilder}; sub searches are added first, followed by kNN searches, each in their original order.
*
* @param source the search source providing the sub searches and kNN searches to convert
* @param clusterSupportsFeature predicate for checking whether a given {@link NodeFeature} is supported
* @return the equivalent RRF retriever, or {@code null} when
*         {@code RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED} is not supported
* @throws IllegalArgumentException if fewer than 2 sub searches and kNN searches are defined in total
*/
@Override
196+
public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate<NodeFeature> clusterSupportsFeature) {
197+
// Without retriever composition support no transformation is done; the caller receives null.
if (false == clusterSupportsFeature.test(RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED)) {
198+
return null;
199+
}
200+
// Sub searches and kNN searches both count towards the number of queries being combined.
int totalQueries = source.subSearches().size() + source.knnSearch().size();
201+
if (totalQueries < 2) {
202+
throw new IllegalArgumentException("[rrf] requires at least 2 sub-queries to be defined");
203+
}
204+
List<CompoundRetrieverBuilder.RetrieverSource> retrieverSources = new ArrayList<>(totalQueries);
205+
// Wrap each sub search query in a standard retriever, reusing the query name as the retriever name.
for (int i = 0; i < source.subSearches().size(); i++) {
206+
RetrieverBuilder standardRetriever = new StandardRetrieverBuilder(source.subSearches().get(i).getQueryBuilder());
207+
standardRetriever.retrieverName(source.subSearches().get(i).getQueryBuilder().queryName());
208+
// NOTE(review): the meaning of the null second argument of RetrieverSource is not visible here - confirm.
retrieverSources.add(new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null));
209+
}
210+
// Convert each kNN search to a kNN retriever carrying over its field, vector, k, numCands and similarity.
for (int i = 0; i < source.knnSearch().size(); i++) {
211+
KnnSearchBuilder knnSearchBuilder = source.knnSearch().get(i);
212+
RetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
213+
knnSearchBuilder.getField(),
214+
// NOTE(review): assumes getQueryVector() is non-null even when only a query vector builder was given - confirm.
knnSearchBuilder.getQueryVector().asFloatVector(),
215+
knnSearchBuilder.getQueryVectorBuilder(),
216+
knnSearchBuilder.k(),
217+
knnSearchBuilder.getNumCands(),
218+
knnSearchBuilder.getSimilarity()
219+
);
220+
knnRetriever.retrieverName(knnSearchBuilder.queryName());
221+
retrieverSources.add(new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null));
222+
}
223+
// The rank window size and rank constant of this rank builder are carried over to the retriever.
return new RRFRetrieverBuilder(retrieverSources, rankWindowSize(), rankConstant());
224+
}
225+
186226
@Override
187227
protected boolean doEquals(RankBuilder other) {
188228
return Objects.equals(rankConstant, ((RRFRankBuilder) other).rankConstant);

x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
setup:
22
- requires:
3-
cluster_features: "gte_v8.8.0"
4-
reason: 'rank added in 8.8'
3+
capabilities:
4+
- method: POST
5+
path: /_search
6+
capabilities: [ transform_rank_rrf_to_retriever ]
7+
test_runner_features: capabilities
8+
reason: "Support for transforming deprecated rank_rrf queries to the corresponding rrf retriever is required"
59
- skip:
610
features: "warnings"
711

@@ -212,7 +216,7 @@ setup:
212216
"RRF rank should fail if size > rank_window_size":
213217

214218
- do:
215-
catch: "/\\[rank\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/"
219+
catch: "/\\[rrf\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/"
216220
search:
217221
index: test
218222
body:
@@ -284,3 +288,22 @@ setup:
284288
rank_window_size: 10
285289
rank_constant: 0.3
286290
size: 10
291+
292+
---
293+
"RRF rank should fail if we specify both rank and retriever":
294+
- do:
295+
catch: "/Cannot specify both \\[rank\\] and \\[retriever\\]./"
296+
search:
297+
index: test
298+
body:
299+
track_total_hits: true
300+
fields: [ "text", "keyword" ]
301+
retriever:
302+
standard:
303+
query:
304+
match_all: {}
305+
rank:
306+
rrf:
307+
rank_window_size: 10
308+
rank_constant: 10
309+
size: 10

0 commit comments

Comments
 (0)