Skip to content

Commit d2da66d

Browse files
Use inheritance instead of composition to simplify search phase transitions
We only need the extensibility for testing and it's a lot easier to reason about the code if we have explicit methods instead of overly complicated composition with lots of redundant references being retained all over the place. -> lets simplify to inheritance and get shorter code that performs more predictably (especially when it comes to memory) as a first step. This also opens up the possibility of further simplifications and removing more retained state/memory as we go through the search phases.
1 parent ad1938d commit d2da66d

File tree

8 files changed

+151
-182
lines changed

8 files changed

+151
-182
lines changed

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import org.apache.lucene.search.ScoreDoc;
1212
import org.apache.lucene.search.join.ScoreMode;
13+
import org.elasticsearch.client.internal.Client;
1314
import org.elasticsearch.common.lucene.Lucene;
1415
import org.elasticsearch.index.query.NestedQueryBuilder;
1516
import org.elasticsearch.index.query.QueryBuilder;
@@ -29,7 +30,6 @@
2930
import java.util.ArrayList;
3031
import java.util.Comparator;
3132
import java.util.List;
32-
import java.util.function.Function;
3333

3434
/**
3535
* This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all
@@ -38,43 +38,47 @@
3838
* operation.
3939
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
4040
*/
41-
final class DfsQueryPhase extends SearchPhase {
41+
class DfsQueryPhase extends SearchPhase {
4242
private final SearchPhaseResults<SearchPhaseResult> queryResult;
4343
private final List<DfsSearchResult> searchResults;
4444
private final AggregatedDfs dfs;
4545
private final List<DfsKnnResults> knnResults;
46-
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
46+
private final Client client;
4747
private final AbstractSearchAsyncAction<?> context;
4848
private final SearchTransportService searchTransportService;
4949
private final SearchProgressListener progressListener;
5050

5151
DfsQueryPhase(
5252
List<DfsSearchResult> searchResults,
53-
AggregatedDfs dfs,
5453
List<DfsKnnResults> knnResults,
5554
SearchPhaseResults<SearchPhaseResult> queryResult,
56-
Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
55+
Client client,
5756
AbstractSearchAsyncAction<?> context
5857
) {
5958
super("dfs_query");
6059
this.progressListener = context.getTask().getProgressListener();
6160
this.queryResult = queryResult;
6261
this.searchResults = searchResults;
63-
this.dfs = dfs;
62+
this.dfs = SearchPhaseController.aggregateDfs(searchResults);
6463
this.knnResults = knnResults;
65-
this.nextPhaseFactory = nextPhaseFactory;
64+
this.client = client;
6665
this.context = context;
6766
this.searchTransportService = context.getSearchTransport();
6867
}
6968

69+
// protected for testing
70+
protected SearchPhase nextPhase() {
71+
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
72+
}
73+
7074
@Override
7175
public void run() {
7276
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
7377
// to free up memory early
7478
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
7579
queryResult,
7680
searchResults.size(),
77-
() -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)),
81+
() -> context.executeNextPhase(this, this::nextPhase),
7882
context
7983
);
8084

server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,44 @@
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.common.Strings;
1414
import org.elasticsearch.common.util.Maps;
15+
import org.elasticsearch.common.util.concurrent.AtomicArray;
1516
import org.elasticsearch.index.query.BoolQueryBuilder;
1617
import org.elasticsearch.index.query.InnerHitBuilder;
1718
import org.elasticsearch.index.query.QueryBuilder;
1819
import org.elasticsearch.index.query.QueryBuilders;
1920
import org.elasticsearch.search.SearchHit;
2021
import org.elasticsearch.search.SearchHits;
22+
import org.elasticsearch.search.SearchPhaseResult;
2123
import org.elasticsearch.search.builder.SearchSourceBuilder;
2224
import org.elasticsearch.search.collapse.CollapseBuilder;
2325

2426
import java.util.Iterator;
2527
import java.util.List;
26-
import java.util.function.Supplier;
2728

2829
/**
2930
* This search phase is an optional phase that will be executed once all hits are fetched from the shards that executes
3031
* field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise
3132
* forwards to the next phase immediately.
3233
*/
33-
final class ExpandSearchPhase extends SearchPhase {
34+
class ExpandSearchPhase extends SearchPhase {
3435
private final AbstractSearchAsyncAction<?> context;
35-
private final SearchHits searchHits;
36-
private final Supplier<SearchPhase> nextPhase;
36+
private final SearchResponseSections searchResponseSections;
37+
private final AtomicArray<SearchPhaseResult> queryPhaseResults;
3738

38-
ExpandSearchPhase(AbstractSearchAsyncAction<?> context, SearchHits searchHits, Supplier<SearchPhase> nextPhase) {
39+
ExpandSearchPhase(
40+
AbstractSearchAsyncAction<?> context,
41+
SearchResponseSections searchResponseSections,
42+
AtomicArray<SearchPhaseResult> queryPhaseResults
43+
) {
3944
super("expand");
4045
this.context = context;
41-
this.searchHits = searchHits;
42-
this.nextPhase = nextPhase;
46+
this.searchResponseSections = searchResponseSections;
47+
this.queryPhaseResults = queryPhaseResults;
48+
}
49+
50+
// protected for tests
51+
protected SearchPhase nextPhase() {
52+
return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults);
4353
}
4454

4555
/**
@@ -52,14 +62,15 @@ private boolean isCollapseRequest() {
5262

5363
@Override
5464
public void run() {
65+
var searchHits = searchResponseSections.hits();
5566
if (isCollapseRequest() == false || searchHits.getHits().length == 0) {
5667
onPhaseDone();
5768
} else {
58-
doRun();
69+
doRun(searchHits);
5970
}
6071
}
6172

62-
private void doRun() {
73+
private void doRun(SearchHits searchHits) {
6374
SearchRequest searchRequest = context.getRequest();
6475
CollapseBuilder collapseBuilder = searchRequest.source().collapse();
6576
final List<InnerHitBuilder> innerHitBuilders = collapseBuilder.getInnerHits();
@@ -168,6 +179,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde
168179
}
169180

170181
private void onPhaseDone() {
171-
context.executeNextPhase(this, nextPhase);
182+
context.executeNextPhase(this, this::nextPhase);
172183
}
173184
}

server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ final class FetchLookupFieldsPhase extends SearchPhase {
4848
this.queryResults = queryResults;
4949
}
5050

51-
private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {
52-
53-
}
51+
private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {}
5452

5553
private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHits) {
5654
final Map<String, List<SearchHit>> perClusters = new HashMap<>();
@@ -77,7 +75,7 @@ private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHi
7775
public void run() {
7876
final List<Cluster> clusters = groupLookupFieldsByClusterAlias(searchResponse.hits);
7977
if (clusters.isEmpty()) {
80-
context.sendSearchResponse(searchResponse, queryResults);
78+
sendResponse();
8179
return;
8280
}
8381
doRun(clusters);
@@ -129,9 +127,9 @@ public void onResponse(MultiSearchResponse items) {
129127
}
130128
}
131129
if (failure != null) {
132-
context.onPhaseFailure(FetchLookupFieldsPhase.this, "failed to fetch lookup fields", failure);
130+
onFailure(failure);
133131
} else {
134-
context.sendSearchResponse(searchResponse, queryResults);
132+
sendResponse();
135133
}
136134
}
137135

@@ -141,4 +139,8 @@ public void onFailure(Exception e) {
141139
}
142140
});
143141
}
142+
143+
private void sendResponse() {
144+
context.sendSearchResponse(searchResponse, queryResults);
145+
}
144146
}

server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@
2727
import java.util.HashMap;
2828
import java.util.List;
2929
import java.util.Map;
30-
import java.util.function.BiFunction;
3130

3231
/**
3332
* This search phase merges the query results from the previous phase together and calculates the topN hits for this search.
3433
* Then it reaches out to all relevant shards to fetch the topN hits.
3534
*/
36-
final class FetchSearchPhase extends SearchPhase {
35+
class FetchSearchPhase extends SearchPhase {
3736
private final AtomicArray<SearchPhaseResult> searchPhaseShardResults;
38-
private final BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
3937
private final AbstractSearchAsyncAction<?> context;
4038
private final Logger logger;
4139
private final SearchProgressListener progressListener;
@@ -49,26 +47,6 @@ final class FetchSearchPhase extends SearchPhase {
4947
AggregatedDfs aggregatedDfs,
5048
AbstractSearchAsyncAction<?> context,
5149
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
52-
) {
53-
this(
54-
resultConsumer,
55-
aggregatedDfs,
56-
context,
57-
reducedQueryPhase,
58-
(response, queryPhaseResults) -> new ExpandSearchPhase(
59-
context,
60-
response.hits,
61-
() -> new FetchLookupFieldsPhase(context, response, queryPhaseResults)
62-
)
63-
);
64-
}
65-
66-
FetchSearchPhase(
67-
SearchPhaseResults<SearchPhaseResult> resultConsumer,
68-
AggregatedDfs aggregatedDfs,
69-
AbstractSearchAsyncAction<?> context,
70-
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
71-
BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
7250
) {
7351
super("fetch");
7452
if (context.getNumShards() != resultConsumer.getNumShards()) {
@@ -81,16 +59,20 @@ final class FetchSearchPhase extends SearchPhase {
8159
}
8260
this.searchPhaseShardResults = resultConsumer.getAtomicArray();
8361
this.aggregatedDfs = aggregatedDfs;
84-
this.nextPhaseFactory = nextPhaseFactory;
8562
this.context = context;
8663
this.logger = context.getLogger();
8764
this.progressListener = context.getTask().getProgressListener();
8865
this.reducedQueryPhase = reducedQueryPhase;
8966
this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null;
9067
}
9168

69+
// protected for tests
70+
protected SearchPhase nextPhase(SearchResponseSections searchResponseSections, AtomicArray<SearchPhaseResult> queryPhaseResults) {
71+
return new ExpandSearchPhase(context, searchResponseSections, queryPhaseResults);
72+
}
73+
9274
@Override
93-
public void run() {
75+
public final void run() {
9476
context.execute(new AbstractRunnable() {
9577

9678
@Override
@@ -112,7 +94,7 @@ private void innerRun() throws Exception {
11294
final int numShards = context.getNumShards();
11395
// Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
11496
// still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
115-
final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1
97+
final boolean queryAndFetchOptimization = numShards == 1
11698
&& context.getRequest().hasKnnSearch() == false
11799
&& reducedQueryPhase.queryPhaseRankCoordinatorContext() == null
118100
&& (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null);
@@ -127,7 +109,7 @@ private void innerRun() throws Exception {
127109
// we have to release contexts here to free up resources
128110
searchPhaseShardResults.asList()
129111
.forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
130-
moveToNextPhase(new AtomicArray<>(numShards), reducedQueryPhase);
112+
moveToNextPhase(new AtomicArray<>(0), reducedQueryPhase);
131113
} else {
132114
innerRunFetch(scoreDocs, numShards, reducedQueryPhase);
133115
}
@@ -272,7 +254,7 @@ private void moveToNextPhase(
272254
context.executeNextPhase(this, () -> {
273255
var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
274256
context.addReleasable(resp::decRef);
275-
return nextPhaseFactory.apply(resp, searchPhaseShardResults);
257+
return nextPhase(resp, searchPhaseShardResults);
276258
});
277259
}
278260

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1818
import org.elasticsearch.search.SearchPhaseResult;
1919
import org.elasticsearch.search.SearchShardTarget;
20-
import org.elasticsearch.search.dfs.AggregatedDfs;
2120
import org.elasticsearch.search.dfs.DfsKnnResults;
2221
import org.elasticsearch.search.dfs.DfsSearchResult;
2322
import org.elasticsearch.search.internal.AliasFilter;
@@ -93,16 +92,8 @@ protected void executePhaseOnShard(
9392
@Override
9493
protected SearchPhase getNextPhase() {
9594
final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
96-
final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults);
9795
final List<DfsKnnResults> mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);
98-
return new DfsQueryPhase(
99-
dfsSearchResults,
100-
aggregatedDfs,
101-
mergedKnnResults,
102-
queryPhaseResultConsumer,
103-
(queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, this, queryResults, aggregatedDfs),
104-
this
105-
);
96+
return new DfsQueryPhase(dfsSearchResults, mergedKnnResults, queryPhaseResultConsumer, client, this);
10697
}
10798

10899
@Override

server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,7 @@ public void sendExecuteQuery(
139139
exc -> {}
140140
)
141141
) {
142-
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") {
143-
@Override
144-
public void run() throws IOException {
145-
responseRef.set(((QueryPhaseResultConsumer) response).results);
146-
}
147-
}, mockSearchPhaseContext);
142+
DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef);
148143
assertEquals("dfs_query", phase.getName());
149144
phase.run();
150145
mockSearchPhaseContext.assertNoFailure();
@@ -225,12 +220,7 @@ public void sendExecuteQuery(
225220
exc -> {}
226221
)
227222
) {
228-
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") {
229-
@Override
230-
public void run() throws IOException {
231-
responseRef.set(((QueryPhaseResultConsumer) response).results);
232-
}
233-
}, mockSearchPhaseContext);
223+
DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef);
234224
assertEquals("dfs_query", phase.getName());
235225
phase.run();
236226
mockSearchPhaseContext.assertNoFailure();
@@ -313,12 +303,7 @@ public void sendExecuteQuery(
313303
exc -> {}
314304
)
315305
) {
316-
DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") {
317-
@Override
318-
public void run() throws IOException {
319-
responseRef.set(((QueryPhaseResultConsumer) response).results);
320-
}
321-
}, mockSearchPhaseContext);
306+
DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef);
322307
assertEquals("dfs_query", phase.getName());
323308
phase.run();
324309
assertThat(mockSearchPhaseContext.failures, hasSize(1));
@@ -328,6 +313,25 @@ public void run() throws IOException {
328313
}
329314
}
330315

316+
private static DfsQueryPhase makeDfsPhase(
317+
AtomicArray<DfsSearchResult> results,
318+
SearchPhaseResults<SearchPhaseResult> consumer,
319+
MockSearchPhaseContext mockSearchPhaseContext,
320+
AtomicReference<AtomicArray<SearchPhaseResult>> responseRef
321+
) {
322+
return new DfsQueryPhase(results.asList(), null, consumer, null, mockSearchPhaseContext) {
323+
@Override
324+
protected SearchPhase nextPhase() {
325+
return new SearchPhase("test") {
326+
@Override
327+
public void run() {
328+
responseRef.set(((QueryPhaseResultConsumer) consumer).results);
329+
}
330+
};
331+
}
332+
};
333+
}
334+
331335
public void testRewriteShardSearchRequestWithRank() {
332336
List<DfsKnnResults> dkrs = List.of(
333337
new DfsKnnResults(null, new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1), new ScoreDoc(7, 0.1f, 2) }),
@@ -338,7 +342,7 @@ public void testRewriteShardSearchRequestWithRank() {
338342
);
339343
MockSearchPhaseContext mspc = new MockSearchPhaseContext(2);
340344
mspc.searchTransport = new SearchTransportService(null, null, null);
341-
DfsQueryPhase dqp = new DfsQueryPhase(null, null, dkrs, mock(QueryPhaseResultConsumer.class), null, mspc);
345+
DfsQueryPhase dqp = new DfsQueryPhase(List.of(), dkrs, mock(QueryPhaseResultConsumer.class), null, mspc);
342346

343347
QueryBuilder bm25 = new TermQueryBuilder("field", "term");
344348
SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25)

0 commit comments

Comments
 (0)