Skip to content

Commit 52be52b

Browse files
authored
Add sub_searches to the search endpoint (#96224)
This change adds a new top-level element called sub_searches to the search endpoint. This top-level element allows for a list of additional searches, where each "sub search" has its query executed separately as part of ranking; the results of those queries are later combined into a single final set of documents based on the ranking algorithm.
1 parent dd1d157 commit 52be52b

File tree

27 files changed

+2989
-691
lines changed

27 files changed

+2989
-691
lines changed

docs/changelog/96224.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
pr: 96224
2+
summary: Add multiple queries for ranking to the search endpoint
3+
area: Ranking
4+
type: enhancement
5+
issues: []
6+
highlight:
7+
title: Add multiple queries for ranking to the search endpoint
8+
body: "The search endpoint adds a new top-level element called `sub_searches`. \
9+
This top-level element is a list of searches used for ranking where each \
10+
\"sub search\" is executed independently. The `sub_searches` element is \
11+
used instead of `query` to allow a search using ranking to execute \
12+
multiple queries. The `sub_searches` and `query` elements cannot be used \
13+
together."
14+
notable: true

modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportActionTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ public void testGetTimeSeriesDataStream() {
199199
);
200200
}
201201

202+
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/96672")
202203
public void testGetTimeSeriesMixedDataStream() {
203204
Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
204205
String dataStream1 = "ds-1";

server/src/main/java/org/elasticsearch/TransportVersion.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ private static TransportVersion registerTransportVersion(int id, String uniqueId
136136
public static final TransportVersion V_8_500_010 = registerTransportVersion(8_500_010, "9818C628-1EEC-439B-B943-468F61460675");
137137
public static final TransportVersion V_8_500_011 = registerTransportVersion(8_500_011, "2209F28D-B52E-4BC4-9889-E780F291C32E");
138138
public static final TransportVersion V_8_500_012 = registerTransportVersion(8_500_012, "BB6F4AF1-A860-4FD4-A138-8150FFBE0ABD");
139+
public static final TransportVersion V_8_500_013 = registerTransportVersion(8_500_013, "f65b85ac-db5e-4558-a487-a1dde4f6a33a");
139140

140141
private static class CurrentHolder {
141-
private static final TransportVersion CURRENT = findCurrent(V_8_500_012);
142+
private static final TransportVersion CURRENT = findCurrent(V_8_500_013);
142143

143144
// finds the pluggable current version, or uses the given fallback
144145
private static TransportVersion findCurrent(TransportVersion fallback) {

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 10 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
package org.elasticsearch.action.search;
99

1010
import org.apache.lucene.search.ScoreDoc;
11-
import org.elasticsearch.index.query.BoolQueryBuilder;
12-
import org.elasticsearch.index.query.QueryBuilder;
1311
import org.elasticsearch.search.SearchPhaseResult;
1412
import org.elasticsearch.search.SearchShardTarget;
1513
import org.elasticsearch.search.builder.SearchSourceBuilder;
14+
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
1615
import org.elasticsearch.search.dfs.AggregatedDfs;
1716
import org.elasticsearch.search.dfs.DfsKnnResults;
1817
import org.elasticsearch.search.dfs.DfsSearchResult;
@@ -138,75 +137,23 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
138137
return request;
139138
}
140139

141-
if (source.rankBuilder() == null) {
142-
// this path will use linear combination if there are
143-
// multiple knn queries to combine all knn queries into
144-
// a single query per shard
140+
List<SubSearchSourceBuilder> subSearchSourceBuilders = new ArrayList<>(source.subSearches());
145141

142+
for (DfsKnnResults dfsKnnResults : knnResults) {
146143
List<ScoreDoc> scoreDocs = new ArrayList<>();
147-
for (DfsKnnResults dfsKnnResults : knnResults) {
148-
for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
149-
if (scoreDoc.shardIndex == request.shardRequestIndex()) {
150-
scoreDocs.add(scoreDoc);
151-
}
144+
for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
145+
if (scoreDoc.shardIndex == request.shardRequestIndex()) {
146+
scoreDocs.add(scoreDoc);
152147
}
153148
}
154149
scoreDocs.sort(Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
155-
// It is possible that the different results refer to the same doc.
156-
for (int i = 0; i < scoreDocs.size() - 1; i++) {
157-
ScoreDoc scoreDoc = scoreDocs.get(i);
158-
int j = i + 1;
159-
for (; j < scoreDocs.size(); j++) {
160-
ScoreDoc otherScoreDoc = scoreDocs.get(j);
161-
if (otherScoreDoc.doc != scoreDoc.doc) {
162-
break;
163-
}
164-
scoreDoc.score += otherScoreDoc.score;
165-
}
166-
if (j > i + 1) {
167-
scoreDocs.subList(i + 1, j).clear();
168-
}
169-
}
170-
171150
KnnScoreDocQueryBuilder knnQuery = new KnnScoreDocQueryBuilder(scoreDocs.toArray(new ScoreDoc[0]));
172-
SearchSourceBuilder newSource = source.shallowCopy().knnSearch(List.of());
173-
if (source.query() == null) {
174-
newSource.query(knnQuery);
175-
} else {
176-
newSource.query(new BoolQueryBuilder().should(knnQuery).should(source.query()));
177-
}
178-
request.source(newSource);
179-
} else {
180-
// this path will keep knn queries separate for ranking per shard
181-
// if there are multiple knn queries
182-
183-
List<QueryBuilder> rankQueryBuilders = new ArrayList<>();
184-
if (source.query() != null) {
185-
rankQueryBuilders.add(source.query());
186-
}
187-
188-
for (DfsKnnResults dfsKnnResults : knnResults) {
189-
List<ScoreDoc> scoreDocs = new ArrayList<>();
190-
for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
191-
if (scoreDoc.shardIndex == request.shardRequestIndex()) {
192-
scoreDocs.add(scoreDoc);
193-
}
194-
}
195-
scoreDocs.sort(Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
196-
KnnScoreDocQueryBuilder knnQuery = new KnnScoreDocQueryBuilder(scoreDocs.toArray(new ScoreDoc[0]));
197-
rankQueryBuilders.add(knnQuery);
198-
}
199-
200-
BoolQueryBuilder searchQuery = new BoolQueryBuilder();
201-
for (QueryBuilder queryBuilder : rankQueryBuilders) {
202-
searchQuery.should(queryBuilder);
203-
}
204-
205-
SearchSourceBuilder newSource = source.shallowCopy().query(searchQuery).knnSearch(List.of());
206-
request.source(newSource);
207-
request.rankQueryBuilders(rankQueryBuilders);
151+
subSearchSourceBuilders.add(new SubSearchSourceBuilder(knnQuery));
208152
}
209153

154+
source = source.shallowCopy().subSearches(subSearchSourceBuilders).knnSearch(List.of());
155+
request.source(source);
156+
210157
return request;
211158
}
212159
}

server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ private void innerRun() throws Exception {
104104
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
105105
// Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
106106
// still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
107-
final boolean queryAndFetchOptimization = queryResults.length() == 1 && context.getRequest().hasKnnSearch() == false;
107+
final boolean queryAndFetchOptimization = queryResults.length() == 1
108+
&& context.getRequest().hasKnnSearch() == false
109+
&& reducedQueryPhase.rankCoordinatorContext() == null;
108110
final Runnable finishPhase = () -> moveToNextPhase(
109111
queryResults,
110112
reducedQueryPhase,

server/src/main/java/org/elasticsearch/action/search/SearchRequest.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ public ActionRequestValidationException validate() {
359359
}
360360
}
361361
if (source != null) {
362+
if (source.subSearches().size() >= 2 && source.rankBuilder() == null) {
363+
validationException = addValidationError("[sub_searches] requires [rank]", validationException);
364+
}
362365
if (source.aggregations() != null) {
363366
validationException = source.aggregations().validate(validationException);
364367
}
@@ -378,10 +381,10 @@ public ActionRequestValidationException validate() {
378381
validationException
379382
);
380383
}
381-
if (source.knnSearch().isEmpty() || source.query() == null && source.knnSearch().size() < 2) {
384+
int queryCount = source.subSearches().size() + source.knnSearch().size();
385+
if (queryCount < 2) {
382386
validationException = addValidationError(
383-
"[rank] requires a minimum of [2] result sets which"
384-
+ " either needs at minimum [a query and a knn search] or [multiple knn searches]",
387+
"[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches",
385388
validationException
386389
);
387390
}

server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2121
import org.elasticsearch.search.builder.PointInTimeBuilder;
2222
import org.elasticsearch.search.builder.SearchSourceBuilder;
23+
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
2324
import org.elasticsearch.search.collapse.CollapseBuilder;
2425
import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
2526
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
@@ -165,6 +166,14 @@ public SearchRequestBuilder setQuery(QueryBuilder queryBuilder) {
165166
return this;
166167
}
167168

169+
/**
170+
* Constructs a new search source builder with a list of sub searches.
171+
*/
172+
public SearchRequestBuilder setSubSearches(List<SubSearchSourceBuilder> subSearches) {
173+
sourceBuilder().subSearches(subSearches);
174+
return this;
175+
}
176+
168177
/**
169178
* Sets a filter that will be executed after the query has been executed and only has affect on the search hits
170179
* (not aggregations). This filter is always executed as last filtering mechanism.

server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ public void sendExecuteQuery(
236236
) {
237237
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
238238
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.
239-
final boolean fetchDocuments = request.numberOfShards() == 1;
239+
final boolean fetchDocuments = request.numberOfShards() == 1
240+
&& (request.source() == null || request.source().rankBuilder() == null);
240241
Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : in -> new QuerySearchResult(in, true);
241242

242243
final ActionListener<? super SearchPhaseResult> handler = responseWrapper.apply(connection, listener);

server/src/main/java/org/elasticsearch/search/SearchService.java

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
import org.elasticsearch.search.aggregations.support.AggregationContext;
7878
import org.elasticsearch.search.aggregations.support.AggregationContext.ProductionAggregationContext;
7979
import org.elasticsearch.search.builder.SearchSourceBuilder;
80+
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
8081
import org.elasticsearch.search.collapse.CollapseContext;
8182
import org.elasticsearch.search.dfs.DfsPhase;
8283
import org.elasticsearch.search.dfs.DfsSearchResult;
@@ -626,7 +627,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
626627
} finally {
627628
tracer.stopTrace();
628629
}
629-
if (request.numberOfShards() == 1) {
630+
if (request.numberOfShards() == 1 && (request.source() == null || request.source().rankBuilder() == null)) {
630631
// we already have query results, but we can run fetch at the same time
631632
context.addFetchResult();
632633
return executeFetchPhase(readerContext, context, afterQueryTime);
@@ -992,13 +993,6 @@ protected SearchContext createContext(
992993
if (context.size() == -1) {
993994
context.size(DEFAULT_SIZE);
994995
}
995-
if (request.rankQueryBuilders().isEmpty() == false) {
996-
List<Query> rankQueries = new ArrayList<>();
997-
for (QueryBuilder queryBuilder : request.rankQueryBuilders()) {
998-
rankQueries.add(queryBuilder.toQuery(context.getSearchExecutionContext()));
999-
}
1000-
context.rankShardContext(request.source().rankBuilder().buildRankShardContext(rankQueries, context.from()));
1001-
}
1002996
context.setTask(task);
1003997

1004998
context.preProcess();
@@ -1166,7 +1160,7 @@ private void processFailure(ReaderContext context, Exception exc) {
11661160
}
11671161
}
11681162

1169-
private void parseSource(DefaultSearchContext context, SearchSourceBuilder source, boolean includeAggregations) {
1163+
private void parseSource(DefaultSearchContext context, SearchSourceBuilder source, boolean includeAggregations) throws IOException {
11701164
// nothing to parse...
11711165
if (source == null) {
11721166
return;
@@ -1176,9 +1170,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
11761170
context.from(source.from());
11771171
context.size(source.size());
11781172
Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
1179-
if (source.query() != null) {
1180-
InnerHitContextBuilder.extractInnerHits(source.query(), innerHitBuilders);
1181-
context.parsedQuery(searchExecutionContext.toQuery(source.query()));
1173+
QueryBuilder query = source.query();
1174+
if (query != null) {
1175+
InnerHitContextBuilder.extractInnerHits(query, innerHitBuilders);
1176+
context.parsedQuery(searchExecutionContext.toQuery(query));
11821177
}
11831178
if (source.postFilter() != null) {
11841179
InnerHitContextBuilder.extractInnerHits(source.postFilter(), innerHitBuilders);
@@ -1374,6 +1369,14 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
13741369
final CollapseContext collapseContext = source.collapse().build(searchExecutionContext);
13751370
context.collapse(collapseContext);
13761371
}
1372+
1373+
if (source.rankBuilder() != null) {
1374+
List<Query> queries = new ArrayList<>();
1375+
for (SubSearchSourceBuilder subSearchSourceBuilder : source.subSearches()) {
1376+
queries.add(subSearchSourceBuilder.toSearchQuery(context.getSearchExecutionContext()));
1377+
}
1378+
context.rankShardContext(source.rankBuilder().buildRankShardContext(queries, context.from()));
1379+
}
13771380
}
13781381

13791382
/**
@@ -1623,14 +1626,12 @@ private static boolean canMatchAfterRewrite(final ShardSearchRequest request, fi
16231626
@SuppressWarnings("unchecked")
16241627
public static boolean queryStillMatchesAfterRewrite(ShardSearchRequest request, QueryRewriteContext context) throws IOException {
16251628
Rewriteable.rewrite(request.getRewriteable(), context, false);
1626-
final boolean aliasFilterCanMatch = request.getAliasFilter().getQueryBuilder() instanceof MatchNoneQueryBuilder == false;
1627-
final boolean canMatch;
1629+
boolean canMatch = request.getAliasFilter().getQueryBuilder() instanceof MatchNoneQueryBuilder == false;
16281630
if (canRewriteToMatchNone(request.source())) {
1629-
QueryBuilder queryBuilder = request.source().query();
1630-
canMatch = aliasFilterCanMatch && queryBuilder instanceof MatchNoneQueryBuilder == false;
1631-
} else {
1632-
// null query means match_all
1633-
canMatch = aliasFilterCanMatch;
1631+
canMatch &= request.source()
1632+
.subSearches()
1633+
.stream()
1634+
.anyMatch(sqwb -> sqwb.getQueryBuilder() instanceof MatchNoneQueryBuilder == false);
16341635
}
16351636
return canMatch;
16361637
}
@@ -1641,7 +1642,11 @@ public static boolean queryStillMatchesAfterRewrite(ShardSearchRequest request,
16411642
* a global aggregation is part of this request or if there is a suggest builder present.
16421643
*/
16431644
public static boolean canRewriteToMatchNone(SearchSourceBuilder source) {
1644-
if (source == null || source.query() == null || source.query() instanceof MatchAllQueryBuilder || source.suggest() != null) {
1645+
if (source == null || source.suggest() != null) {
1646+
return false;
1647+
}
1648+
if (source.subSearches().isEmpty()
1649+
|| source.subSearches().stream().anyMatch(sqwb -> sqwb.getQueryBuilder() instanceof MatchAllQueryBuilder)) {
16451650
return false;
16461651
}
16471652
AggregatorFactories.Builder aggregations = source.aggregations();

0 commit comments

Comments
 (0)