Skip to content

Commit 4c28581

Browse files
[8.x] Remove unnecessary interfaces/abstractions from search phases (#120079) (#126511)
* Remove unnecessary interfaces/abstractions from search phases (#120079) A couple of obvious cleanups where declared general interfaces aren't actually used; also, we shouldn't escape potentially heavy-weight search phases just to get their name in a call, so the failure API was adjusted. * Use inheritance instead of composition to simplify search phase transitions (#119272) We only need the extensibility for testing, and it's a lot easier to reason about the code if we have explicit methods instead of overly complicated composition with lots of redundant references being retained all over the place. -> Let's simplify to inheritance and get shorter code that performs more predictably (especially when it comes to memory) as a first step. This also opens up the possibility of further simplifications and removing more retained state/memory as we go through the search phases.
1 parent 17ac046 commit 4c28581

20 files changed

+333
-371
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ public final void start() {
218218
}
219219

220220
@Override
221-
public final void run() {
221+
protected final void run() {
222222
for (final SearchShardIterator iterator : toSkipShardsIts) {
223223
assert iterator.skip();
224224
skipShard(iterator);
@@ -300,7 +300,7 @@ private static boolean assertExecuteOnStartThread() {
300300
return true;
301301
}
302302

303-
protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
303+
private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
304304
if (throttleConcurrentRequests) {
305305
var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
306306
shard.getNodeId(),
@@ -363,7 +363,7 @@ protected abstract void executePhaseOnShard(
363363
* of the next phase. If there are no successful operations in the context when this method is executed the search is aborted and
364364
* a response is returned to the user indicating that all shards have failed.
365365
*/
366-
protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase> nextPhaseSupplier) {
366+
protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextPhaseSupplier) {
367367
/* This is the main search phase transition where we move to the next phase. If all shards
368368
* failed or if there was a failure and partial results are not allowed, then we immediately
369369
* fail. Otherwise we continue to the next phase.
@@ -374,7 +374,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
374374
Throwable cause = shardSearchFailures.length == 0
375375
? null
376376
: ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0];
377-
logger.debug(() -> "All shards failed for phase: [" + currentPhase.getName() + "]", cause);
377+
logger.debug(() -> "All shards failed for phase: [" + currentPhase + "]", cause);
378378
onPhaseFailure(currentPhase, "all shards failed", cause);
379379
} else {
380380
Boolean allowPartialResults = request.allowPartialSearchResults();
@@ -387,7 +387,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
387387
int numShardFailures = shardSearchFailures.length;
388388
shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures);
389389
Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0];
390-
logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase.getName()), cause);
390+
logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause);
391391
}
392392
onPhaseFailure(currentPhase, "Partial shards failure", null);
393393
} else {
@@ -400,7 +400,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
400400
successfulOps.get(),
401401
toSkipShardsIts.size(),
402402
getNumShards(),
403-
currentPhase.getName()
403+
currentPhase
404404
);
405405
}
406406
onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null);
@@ -414,7 +414,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
414414
.collect(Collectors.joining(","));
415415
logger.trace(
416416
"[{}] Moving to next phase: [{}], based on results from: {} (cluster state version: {})",
417-
currentPhase.getName(),
417+
currentPhase,
418418
nextPhase.getName(),
419419
resultsFrom,
420420
clusterStateVersion
@@ -427,11 +427,11 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
427427
private void executePhase(SearchPhase phase) {
428428
try {
429429
phase.run();
430-
} catch (Exception e) {
430+
} catch (RuntimeException e) {
431431
if (logger.isDebugEnabled()) {
432432
logger.debug(() -> format("Failed to execute [%s] while moving to [%s] phase", request, phase.getName()), e);
433433
}
434-
onPhaseFailure(phase, "", e);
434+
onPhaseFailure(phase.getName(), "", e);
435435
}
436436
}
437437

@@ -686,8 +686,8 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At
686686
* @param msg an optional message
687687
* @param cause the cause of the phase failure
688688
*/
689-
public void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
690-
raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures()));
689+
public void onPhaseFailure(String phase, String msg, Throwable cause) {
690+
raisePhaseFailure(new SearchPhaseExecutionException(phase, msg, cause, buildShardFailures()));
691691
}
692692

693693
/**
@@ -732,7 +732,7 @@ void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connecti
732732
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
733733
*/
734734
private void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
735-
executeNextPhase(this, this::getNextPhase);
735+
executeNextPhase(getName(), this::getNextPhase);
736736
}
737737

738738
/**

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 126 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
*/
99
package org.elasticsearch.action.search;
1010

11+
import org.apache.lucene.index.Term;
12+
import org.apache.lucene.search.CollectionStatistics;
1113
import org.apache.lucene.search.ScoreDoc;
14+
import org.apache.lucene.search.TermStatistics;
15+
import org.apache.lucene.search.TopDocs;
16+
import org.apache.lucene.search.TotalHits;
1217
import org.apache.lucene.search.join.ScoreMode;
18+
import org.apache.lucene.util.SetOnce;
19+
import org.elasticsearch.client.internal.Client;
1320
import org.elasticsearch.common.lucene.Lucene;
1421
import org.elasticsearch.index.query.NestedQueryBuilder;
1522
import org.elasticsearch.index.query.QueryBuilder;
@@ -27,9 +34,11 @@
2734
import org.elasticsearch.transport.Transport;
2835

2936
import java.util.ArrayList;
37+
import java.util.Collection;
3038
import java.util.Comparator;
39+
import java.util.HashMap;
3140
import java.util.List;
32-
import java.util.function.Function;
41+
import java.util.Map;
3342

3443
/**
3544
* This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all
@@ -38,53 +47,50 @@
3847
* operation.
3948
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
4049
*/
41-
final class DfsQueryPhase extends SearchPhase {
50+
class DfsQueryPhase extends SearchPhase {
51+
52+
public static final String NAME = "dfs_query";
53+
4254
private final SearchPhaseResults<SearchPhaseResult> queryResult;
43-
private final List<DfsSearchResult> searchResults;
44-
private final AggregatedDfs dfs;
45-
private final List<DfsKnnResults> knnResults;
46-
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
55+
private final Client client;
4756
private final AbstractSearchAsyncAction<?> context;
48-
private final SearchTransportService searchTransportService;
4957
private final SearchProgressListener progressListener;
5058

51-
DfsQueryPhase(
52-
List<DfsSearchResult> searchResults,
53-
AggregatedDfs dfs,
54-
List<DfsKnnResults> knnResults,
55-
SearchPhaseResults<SearchPhaseResult> queryResult,
56-
Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
57-
AbstractSearchAsyncAction<?> context
58-
) {
59-
super("dfs_query");
59+
DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
60+
super(NAME);
6061
this.progressListener = context.getTask().getProgressListener();
6162
this.queryResult = queryResult;
62-
this.searchResults = searchResults;
63-
this.dfs = dfs;
64-
this.knnResults = knnResults;
65-
this.nextPhaseFactory = nextPhaseFactory;
63+
this.client = client;
6664
this.context = context;
67-
this.searchTransportService = context.getSearchTransport();
6865
}
6966

67+
// protected for testing
68+
protected SearchPhase nextPhase(AggregatedDfs dfs) {
69+
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
70+
}
71+
72+
@SuppressWarnings("unchecked")
7073
@Override
71-
public void run() {
74+
protected void run() {
75+
List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
76+
AggregatedDfs dfs = aggregateDfs(searchResults);
7277
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
7378
// to free up memory early
7479
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
7580
queryResult,
7681
searchResults.size(),
77-
() -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)),
82+
() -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
7883
context
7984
);
8085

86+
List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
8187
for (final DfsSearchResult dfsResult : searchResults) {
8288
final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
8389
final int shardIndex = dfsResult.getShardIndex();
8490
QuerySearchRequest querySearchRequest = new QuerySearchRequest(
8591
context.getOriginalIndices(shardIndex),
8692
dfsResult.getContextId(),
87-
rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
93+
rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
8894
dfs
8995
);
9096
final Transport.Connection connection;
@@ -94,19 +100,16 @@ public void run() {
94100
shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
95101
continue;
96102
}
97-
searchTransportService.sendExecuteQuery(
98-
connection,
99-
querySearchRequest,
100-
context.getTask(),
101-
new SearchActionListener<>(shardTarget, shardIndex) {
103+
context.getSearchTransport()
104+
.sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {
102105

103106
@Override
104107
protected void innerOnResponse(QuerySearchResult response) {
105108
try {
106109
response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult());
107110
counter.onResult(response);
108111
} catch (Exception e) {
109-
context.onPhaseFailure(DfsQueryPhase.this, "", e);
112+
context.onPhaseFailure(NAME, "", e);
110113
}
111114
}
112115

@@ -123,8 +126,7 @@ public void onFailure(Exception exception) {
123126
}
124127
}
125128
}
126-
}
127-
);
129+
});
128130
}
129131
}
130132

@@ -141,7 +143,7 @@ private void shardFailure(
141143
}
142144

143145
// package private for testing
144-
ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
146+
ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
145147
SearchSourceBuilder source = request.source();
146148
if (source == null || source.knnSearch().isEmpty()) {
147149
return request;
@@ -177,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
177179

178180
return request;
179181
}
182+
183+
private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
184+
if (request.hasKnnSearch() == false) {
185+
return null;
186+
}
187+
SearchSourceBuilder source = request.source();
188+
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
189+
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
190+
for (int i = 0; i < source.knnSearch().size(); i++) {
191+
topDocsLists.add(new ArrayList<>());
192+
nestedPath.add(new SetOnce<>());
193+
}
194+
195+
for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
196+
if (dfsSearchResult.knnResults() != null) {
197+
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
198+
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
199+
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
200+
TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
201+
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
202+
SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
203+
topDocsLists.get(i).add(shardTopDocs);
204+
nestedPath.get(i).trySet(knnResults.getNestedPath());
205+
}
206+
}
207+
}
208+
209+
List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
210+
for (int i = 0; i < source.knnSearch().size(); i++) {
211+
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
212+
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
213+
}
214+
return mergedResults;
215+
}
216+
217+
private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
218+
Map<Term, TermStatistics> termStatistics = new HashMap<>();
219+
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
220+
long aggMaxDoc = 0;
221+
for (DfsSearchResult lEntry : results) {
222+
final Term[] terms = lEntry.terms();
223+
final TermStatistics[] stats = lEntry.termStatistics();
224+
assert terms.length == stats.length;
225+
for (int i = 0; i < terms.length; i++) {
226+
assert terms[i] != null;
227+
if (stats[i] == null) {
228+
continue;
229+
}
230+
TermStatistics existing = termStatistics.get(terms[i]);
231+
if (existing != null) {
232+
assert terms[i].bytes().equals(existing.term());
233+
termStatistics.put(
234+
terms[i],
235+
new TermStatistics(
236+
existing.term(),
237+
existing.docFreq() + stats[i].docFreq(),
238+
existing.totalTermFreq() + stats[i].totalTermFreq()
239+
)
240+
);
241+
} else {
242+
termStatistics.put(terms[i], stats[i]);
243+
}
244+
245+
}
246+
247+
assert lEntry.fieldStatistics().containsKey(null) == false;
248+
for (var entry : lEntry.fieldStatistics().entrySet()) {
249+
String key = entry.getKey();
250+
CollectionStatistics value = entry.getValue();
251+
if (value == null) {
252+
continue;
253+
}
254+
assert key != null;
255+
CollectionStatistics existing = fieldStatistics.get(key);
256+
if (existing != null) {
257+
CollectionStatistics merged = new CollectionStatistics(
258+
key,
259+
existing.maxDoc() + value.maxDoc(),
260+
existing.docCount() + value.docCount(),
261+
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
262+
existing.sumDocFreq() + value.sumDocFreq()
263+
);
264+
fieldStatistics.put(key, merged);
265+
} else {
266+
fieldStatistics.put(key, value);
267+
}
268+
}
269+
aggMaxDoc += lEntry.maxDoc();
270+
}
271+
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
272+
}
180273
}

0 commit comments

Comments
 (0)