Skip to content

Commit d7474e6

Browse files
Move scroll + dfs reduce code from SearchPhaseController to actual users (#119726)
No need to have this logic live in `SearchPhaseController` when it only has a single callsite elsewhere.
1 parent 430c9fa commit d7474e6

File tree

5 files changed

+156
-137
lines changed

5 files changed

+156
-137
lines changed

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,30 @@
1010
package org.elasticsearch.action.search;
1111

1212
import org.apache.logging.log4j.Logger;
13+
import org.apache.lucene.index.Term;
14+
import org.apache.lucene.search.CollectionStatistics;
15+
import org.apache.lucene.search.ScoreDoc;
16+
import org.apache.lucene.search.TermStatistics;
17+
import org.apache.lucene.search.TopDocs;
18+
import org.apache.lucene.search.TotalHits;
19+
import org.apache.lucene.util.SetOnce;
1320
import org.elasticsearch.action.ActionListener;
1421
import org.elasticsearch.client.internal.Client;
1522
import org.elasticsearch.cluster.ClusterState;
1623
import org.elasticsearch.cluster.routing.GroupShardsIterator;
1724
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1825
import org.elasticsearch.search.SearchPhaseResult;
1926
import org.elasticsearch.search.SearchShardTarget;
27+
import org.elasticsearch.search.builder.SearchSourceBuilder;
2028
import org.elasticsearch.search.dfs.AggregatedDfs;
2129
import org.elasticsearch.search.dfs.DfsKnnResults;
2230
import org.elasticsearch.search.dfs.DfsSearchResult;
2331
import org.elasticsearch.search.internal.AliasFilter;
2432
import org.elasticsearch.transport.Transport;
2533

34+
import java.util.ArrayList;
35+
import java.util.Collection;
36+
import java.util.HashMap;
2637
import java.util.List;
2738
import java.util.Map;
2839
import java.util.concurrent.Executor;
@@ -93,12 +104,11 @@ protected void executePhaseOnShard(
93104
@Override
94105
protected SearchPhase getNextPhase() {
95106
final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
96-
final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults);
97-
final List<DfsKnnResults> mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);
107+
final AggregatedDfs aggregatedDfs = aggregateDfs(dfsSearchResults);
98108
return new DfsQueryPhase(
99109
dfsSearchResults,
100110
aggregatedDfs,
101-
mergedKnnResults,
111+
mergeKnnResults(getRequest(), dfsSearchResults),
102112
queryPhaseResultConsumer,
103113
(queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, this, queryResults, aggregatedDfs),
104114
this
@@ -109,4 +119,95 @@ protected SearchPhase getNextPhase() {
109119
protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
110120
progressListener.notifyQueryFailure(shardIndex, shardTarget, exc);
111121
}
122+
123+
private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
124+
if (request.hasKnnSearch() == false) {
125+
return null;
126+
}
127+
SearchSourceBuilder source = request.source();
128+
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
129+
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
130+
for (int i = 0; i < source.knnSearch().size(); i++) {
131+
topDocsLists.add(new ArrayList<>());
132+
nestedPath.add(new SetOnce<>());
133+
}
134+
135+
for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
136+
if (dfsSearchResult.knnResults() != null) {
137+
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
138+
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
139+
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
140+
TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
141+
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
142+
SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
143+
topDocsLists.get(i).add(shardTopDocs);
144+
nestedPath.get(i).trySet(knnResults.getNestedPath());
145+
}
146+
}
147+
}
148+
149+
List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
150+
for (int i = 0; i < source.knnSearch().size(); i++) {
151+
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
152+
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
153+
}
154+
return mergedResults;
155+
}
156+
157+
private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
158+
Map<Term, TermStatistics> termStatistics = new HashMap<>();
159+
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
160+
long aggMaxDoc = 0;
161+
for (DfsSearchResult lEntry : results) {
162+
final Term[] terms = lEntry.terms();
163+
final TermStatistics[] stats = lEntry.termStatistics();
164+
assert terms.length == stats.length;
165+
for (int i = 0; i < terms.length; i++) {
166+
assert terms[i] != null;
167+
if (stats[i] == null) {
168+
continue;
169+
}
170+
TermStatistics existing = termStatistics.get(terms[i]);
171+
if (existing != null) {
172+
assert terms[i].bytes().equals(existing.term());
173+
termStatistics.put(
174+
terms[i],
175+
new TermStatistics(
176+
existing.term(),
177+
existing.docFreq() + stats[i].docFreq(),
178+
existing.totalTermFreq() + stats[i].totalTermFreq()
179+
)
180+
);
181+
} else {
182+
termStatistics.put(terms[i], stats[i]);
183+
}
184+
185+
}
186+
187+
assert lEntry.fieldStatistics().containsKey(null) == false;
188+
for (var entry : lEntry.fieldStatistics().entrySet()) {
189+
String key = entry.getKey();
190+
CollectionStatistics value = entry.getValue();
191+
if (value == null) {
192+
continue;
193+
}
194+
assert key != null;
195+
CollectionStatistics existing = fieldStatistics.get(key);
196+
if (existing != null) {
197+
CollectionStatistics merged = new CollectionStatistics(
198+
key,
199+
existing.maxDoc() + value.maxDoc(),
200+
existing.docCount() + value.docCount(),
201+
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
202+
existing.sumDocFreq() + value.sumDocFreq()
203+
);
204+
fieldStatistics.put(key, merged);
205+
} else {
206+
fieldStatistics.put(key, value);
207+
}
208+
}
209+
aggMaxDoc += lEntry.maxDoc();
210+
}
211+
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
212+
}
112213
}

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 0 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,16 @@
99

1010
package org.elasticsearch.action.search;
1111

12-
import org.apache.lucene.index.Term;
13-
import org.apache.lucene.search.CollectionStatistics;
1412
import org.apache.lucene.search.FieldDoc;
1513
import org.apache.lucene.search.ScoreDoc;
1614
import org.apache.lucene.search.Sort;
1715
import org.apache.lucene.search.SortField;
1816
import org.apache.lucene.search.SortedNumericSortField;
1917
import org.apache.lucene.search.SortedSetSortField;
20-
import org.apache.lucene.search.TermStatistics;
2118
import org.apache.lucene.search.TopDocs;
2219
import org.apache.lucene.search.TopFieldDocs;
2320
import org.apache.lucene.search.TotalHits;
2421
import org.apache.lucene.search.TotalHits.Relation;
25-
import org.apache.lucene.util.SetOnce;
2622
import org.elasticsearch.common.breaker.CircuitBreaker;
2723
import org.elasticsearch.common.io.stream.DelayableWriteable;
2824
import org.elasticsearch.common.lucene.Lucene;
@@ -42,9 +38,6 @@
4238
import org.elasticsearch.search.aggregations.AggregatorFactories;
4339
import org.elasticsearch.search.aggregations.InternalAggregations;
4440
import org.elasticsearch.search.builder.SearchSourceBuilder;
45-
import org.elasticsearch.search.dfs.AggregatedDfs;
46-
import org.elasticsearch.search.dfs.DfsKnnResults;
47-
import org.elasticsearch.search.dfs.DfsSearchResult;
4841
import org.elasticsearch.search.fetch.FetchSearchResult;
4942
import org.elasticsearch.search.internal.SearchContext;
5043
import org.elasticsearch.search.profile.SearchProfileQueryPhaseResult;
@@ -84,97 +77,6 @@ public SearchPhaseController(
8477
this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
8578
}
8679

87-
public static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
88-
Map<Term, TermStatistics> termStatistics = new HashMap<>();
89-
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
90-
long aggMaxDoc = 0;
91-
for (DfsSearchResult lEntry : results) {
92-
final Term[] terms = lEntry.terms();
93-
final TermStatistics[] stats = lEntry.termStatistics();
94-
assert terms.length == stats.length;
95-
for (int i = 0; i < terms.length; i++) {
96-
assert terms[i] != null;
97-
if (stats[i] == null) {
98-
continue;
99-
}
100-
TermStatistics existing = termStatistics.get(terms[i]);
101-
if (existing != null) {
102-
assert terms[i].bytes().equals(existing.term());
103-
termStatistics.put(
104-
terms[i],
105-
new TermStatistics(
106-
existing.term(),
107-
existing.docFreq() + stats[i].docFreq(),
108-
existing.totalTermFreq() + stats[i].totalTermFreq()
109-
)
110-
);
111-
} else {
112-
termStatistics.put(terms[i], stats[i]);
113-
}
114-
115-
}
116-
117-
assert lEntry.fieldStatistics().containsKey(null) == false;
118-
for (var entry : lEntry.fieldStatistics().entrySet()) {
119-
String key = entry.getKey();
120-
CollectionStatistics value = entry.getValue();
121-
if (value == null) {
122-
continue;
123-
}
124-
assert key != null;
125-
CollectionStatistics existing = fieldStatistics.get(key);
126-
if (existing != null) {
127-
CollectionStatistics merged = new CollectionStatistics(
128-
key,
129-
existing.maxDoc() + value.maxDoc(),
130-
existing.docCount() + value.docCount(),
131-
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
132-
existing.sumDocFreq() + value.sumDocFreq()
133-
);
134-
fieldStatistics.put(key, merged);
135-
} else {
136-
fieldStatistics.put(key, value);
137-
}
138-
}
139-
aggMaxDoc += lEntry.maxDoc();
140-
}
141-
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
142-
}
143-
144-
public static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
145-
if (request.hasKnnSearch() == false) {
146-
return null;
147-
}
148-
SearchSourceBuilder source = request.source();
149-
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
150-
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
151-
for (int i = 0; i < source.knnSearch().size(); i++) {
152-
topDocsLists.add(new ArrayList<>());
153-
nestedPath.add(new SetOnce<>());
154-
}
155-
156-
for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
157-
if (dfsSearchResult.knnResults() != null) {
158-
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
159-
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
160-
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
161-
TotalHits totalHits = new TotalHits(scoreDocs.length, Relation.EQUAL_TO);
162-
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
163-
setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
164-
topDocsLists.get(i).add(shardTopDocs);
165-
nestedPath.get(i).trySet(knnResults.getNestedPath());
166-
}
167-
}
168-
}
169-
170-
List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
171-
for (int i = 0; i < source.knnSearch().size(); i++) {
172-
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
173-
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
174-
}
175-
return mergedResults;
176-
}
177-
17880
/**
17981
* Returns a score doc array of top N search docs across all shards, followed by top suggest docs for each
18082
* named completion suggestion across all shards. If more than one named completion suggestion is specified in the
@@ -496,38 +398,6 @@ private static SearchHits getHits(
496398
);
497399
}
498400

499-
/**
500-
* Reduces the given query results and consumes all aggregations and profile results.
501-
* @param queryResults a list of non-null query shard results
502-
*/
503-
static ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
504-
AggregationReduceContext.Builder aggReduceContextBuilder = new AggregationReduceContext.Builder() {
505-
@Override
506-
public AggregationReduceContext forPartialReduction() {
507-
throw new UnsupportedOperationException("Scroll requests don't have aggs");
508-
}
509-
510-
@Override
511-
public AggregationReduceContext forFinalReduction() {
512-
throw new UnsupportedOperationException("Scroll requests don't have aggs");
513-
}
514-
};
515-
final TopDocsStats topDocsStats = new TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
516-
final List<TopDocs> topDocs = new ArrayList<>();
517-
for (SearchPhaseResult sortedResult : queryResults) {
518-
QuerySearchResult queryResult = sortedResult.queryResult();
519-
final TopDocsAndMaxScore td = queryResult.consumeTopDocs();
520-
assert td != null;
521-
topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly());
522-
// make sure we set the shard index before we add it - the consumer didn't do that yet
523-
if (td.topDocs.scoreDocs.length > 0) {
524-
setShardIndex(td.topDocs, queryResult.getShardIndex());
525-
topDocs.add(td.topDocs);
526-
}
527-
}
528-
return reducedQueryPhase(queryResults, null, topDocs, topDocsStats, 0, true, aggReduceContextBuilder, null, true);
529-
}
530-
531401
/**
532402
* Reduces the given query results and consumes all aggregations and profile results.
533403
* @param queryResults a list of non-null query shard results

server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,27 @@
1010
package org.elasticsearch.action.search;
1111

1212
import org.apache.logging.log4j.Logger;
13+
import org.apache.lucene.search.TopDocs;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.cluster.node.DiscoveryNode;
1516
import org.elasticsearch.cluster.node.DiscoveryNodes;
17+
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
1618
import org.elasticsearch.common.util.concurrent.AtomicArray;
1719
import org.elasticsearch.common.util.concurrent.CountDown;
1820
import org.elasticsearch.core.Nullable;
1921
import org.elasticsearch.search.SearchPhaseResult;
2022
import org.elasticsearch.search.SearchShardTarget;
23+
import org.elasticsearch.search.aggregations.AggregationReduceContext;
2124
import org.elasticsearch.search.internal.InternalScrollSearchRequest;
25+
import org.elasticsearch.search.internal.SearchContext;
2226
import org.elasticsearch.search.internal.ShardSearchContextId;
27+
import org.elasticsearch.search.query.QuerySearchResult;
2328
import org.elasticsearch.transport.RemoteClusterService;
2429
import org.elasticsearch.transport.Transport;
2530

2631
import java.util.ArrayList;
2732
import java.util.Arrays;
33+
import java.util.Collection;
2834
import java.util.HashSet;
2935
import java.util.List;
3036
import java.util.Set;
@@ -301,4 +307,48 @@ protected void onShardFailure(
301307
protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
302308
return searchTransportService.getConnection(clusterAlias, node);
303309
}
310+
311+
/**
312+
* Reduces the given query results and consumes all aggregations and profile results.
313+
* @param queryResults a list of non-null query shard results
314+
*/
315+
protected static SearchPhaseController.ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
316+
AggregationReduceContext.Builder aggReduceContextBuilder = new AggregationReduceContext.Builder() {
317+
@Override
318+
public AggregationReduceContext forPartialReduction() {
319+
throw new UnsupportedOperationException("Scroll requests don't have aggs");
320+
}
321+
322+
@Override
323+
public AggregationReduceContext forFinalReduction() {
324+
throw new UnsupportedOperationException("Scroll requests don't have aggs");
325+
}
326+
};
327+
final SearchPhaseController.TopDocsStats topDocsStats = new SearchPhaseController.TopDocsStats(
328+
SearchContext.TRACK_TOTAL_HITS_ACCURATE
329+
);
330+
final List<TopDocs> topDocs = new ArrayList<>();
331+
for (SearchPhaseResult sortedResult : queryResults) {
332+
QuerySearchResult queryResult = sortedResult.queryResult();
333+
final TopDocsAndMaxScore td = queryResult.consumeTopDocs();
334+
assert td != null;
335+
topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly());
336+
// make sure we set the shard index before we add it - the consumer didn't do that yet
337+
if (td.topDocs.scoreDocs.length > 0) {
338+
SearchPhaseController.setShardIndex(td.topDocs, queryResult.getShardIndex());
339+
topDocs.add(td.topDocs);
340+
}
341+
}
342+
return SearchPhaseController.reducedQueryPhase(
343+
queryResults,
344+
null,
345+
topDocs,
346+
topDocsStats,
347+
0,
348+
true,
349+
aggReduceContextBuilder,
350+
null,
351+
true
352+
);
353+
}
304354
}

server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryAndFetchAsyncAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ protected void executeInitialPhase(
5151

5252
@Override
5353
protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
54-
return sendResponsePhase(SearchPhaseController.reducedScrollQueryPhase(queryFetchResults.asList()), queryFetchResults);
54+
return sendResponsePhase(reducedScrollQueryPhase(queryFetchResults.asList()), queryFetchResults);
5555
}
5656

5757
@Override

0 commit comments

Comments
 (0)