Skip to content

Commit cae7f0a

Browse files
Use inheritance instead of composition to simplify search phase transitions (#119272)
We only need the extensibility for testing, and it's a lot easier to reason about the code if we have explicit methods instead of overly complicated composition with lots of redundant references being retained all over the place. So, as a first step, let's simplify to inheritance and get shorter code that performs more predictably (especially when it comes to memory). This also opens up the possibility of further simplifications and of removing more retained state/memory as we go through the search phases.
1 parent 187b192 commit cae7f0a

File tree

8 files changed

+262
-303
lines changed

8 files changed

+262
-303
lines changed

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
*/
99
package org.elasticsearch.action.search;
1010

11+
import org.apache.lucene.index.Term;
12+
import org.apache.lucene.search.CollectionStatistics;
1113
import org.apache.lucene.search.ScoreDoc;
14+
import org.apache.lucene.search.TermStatistics;
15+
import org.apache.lucene.search.TopDocs;
16+
import org.apache.lucene.search.TotalHits;
1217
import org.apache.lucene.search.join.ScoreMode;
18+
import org.apache.lucene.util.SetOnce;
19+
import org.elasticsearch.client.internal.Client;
1320
import org.elasticsearch.common.lucene.Lucene;
1421
import org.elasticsearch.index.query.NestedQueryBuilder;
1522
import org.elasticsearch.index.query.QueryBuilder;
@@ -27,9 +34,11 @@
2734
import org.elasticsearch.transport.Transport;
2835

2936
import java.util.ArrayList;
37+
import java.util.Collection;
3038
import java.util.Comparator;
39+
import java.util.HashMap;
3140
import java.util.List;
32-
import java.util.function.Function;
41+
import java.util.Map;
3342

3443
/**
3544
* This search phase fans out to every shard to execute a distributed search with pre-collected distributed frequencies for all
@@ -38,56 +47,50 @@
3847
* operation.
3948
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
4049
*/
41-
final class DfsQueryPhase extends SearchPhase {
50+
class DfsQueryPhase extends SearchPhase {
4251

4352
public static final String NAME = "dfs_query";
4453

4554
private final SearchPhaseResults<SearchPhaseResult> queryResult;
46-
private final List<DfsSearchResult> searchResults;
47-
private final AggregatedDfs dfs;
48-
private final List<DfsKnnResults> knnResults;
49-
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
55+
private final Client client;
5056
private final AbstractSearchAsyncAction<?> context;
51-
private final SearchTransportService searchTransportService;
5257
private final SearchProgressListener progressListener;
5358

54-
DfsQueryPhase(
55-
List<DfsSearchResult> searchResults,
56-
AggregatedDfs dfs,
57-
List<DfsKnnResults> knnResults,
58-
SearchPhaseResults<SearchPhaseResult> queryResult,
59-
Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
60-
AbstractSearchAsyncAction<?> context
61-
) {
59+
DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
6260
super(NAME);
6361
this.progressListener = context.getTask().getProgressListener();
6462
this.queryResult = queryResult;
65-
this.searchResults = searchResults;
66-
this.dfs = dfs;
67-
this.knnResults = knnResults;
68-
this.nextPhaseFactory = nextPhaseFactory;
63+
this.client = client;
6964
this.context = context;
70-
this.searchTransportService = context.getSearchTransport();
7165
}
7266

67+
// protected for testing
68+
protected SearchPhase nextPhase(AggregatedDfs dfs) {
69+
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
70+
}
71+
72+
@SuppressWarnings("unchecked")
7373
@Override
7474
protected void run() {
75+
List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
76+
AggregatedDfs dfs = aggregateDfs(searchResults);
7577
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
7678
// to free up memory early
7779
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
7880
queryResult,
7981
searchResults.size(),
80-
() -> context.executeNextPhase(NAME, () -> nextPhaseFactory.apply(queryResult)),
82+
() -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
8183
context
8284
);
8385

86+
List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
8487
for (final DfsSearchResult dfsResult : searchResults) {
8588
final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
8689
final int shardIndex = dfsResult.getShardIndex();
8790
QuerySearchRequest querySearchRequest = new QuerySearchRequest(
8891
context.getOriginalIndices(shardIndex),
8992
dfsResult.getContextId(),
90-
rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
93+
rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
9194
dfs
9295
);
9396
final Transport.Connection connection;
@@ -97,11 +100,8 @@ protected void run() {
97100
shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
98101
continue;
99102
}
100-
searchTransportService.sendExecuteQuery(
101-
connection,
102-
querySearchRequest,
103-
context.getTask(),
104-
new SearchActionListener<>(shardTarget, shardIndex) {
103+
context.getSearchTransport()
104+
.sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {
105105

106106
@Override
107107
protected void innerOnResponse(QuerySearchResult response) {
@@ -126,8 +126,7 @@ public void onFailure(Exception exception) {
126126
}
127127
}
128128
}
129-
}
130-
);
129+
});
131130
}
132131
}
133132

@@ -144,7 +143,7 @@ private void shardFailure(
144143
}
145144

146145
// package private for testing
147-
ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
146+
ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
148147
SearchSourceBuilder source = request.source();
149148
if (source == null || source.knnSearch().isEmpty()) {
150149
return request;
@@ -180,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
180179

181180
return request;
182181
}
182+
183+
private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
184+
if (request.hasKnnSearch() == false) {
185+
return null;
186+
}
187+
SearchSourceBuilder source = request.source();
188+
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
189+
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
190+
for (int i = 0; i < source.knnSearch().size(); i++) {
191+
topDocsLists.add(new ArrayList<>());
192+
nestedPath.add(new SetOnce<>());
193+
}
194+
195+
for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
196+
if (dfsSearchResult.knnResults() != null) {
197+
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
198+
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
199+
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
200+
TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
201+
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
202+
SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
203+
topDocsLists.get(i).add(shardTopDocs);
204+
nestedPath.get(i).trySet(knnResults.getNestedPath());
205+
}
206+
}
207+
}
208+
209+
List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
210+
for (int i = 0; i < source.knnSearch().size(); i++) {
211+
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
212+
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
213+
}
214+
return mergedResults;
215+
}
216+
217+
private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
218+
Map<Term, TermStatistics> termStatistics = new HashMap<>();
219+
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
220+
long aggMaxDoc = 0;
221+
for (DfsSearchResult lEntry : results) {
222+
final Term[] terms = lEntry.terms();
223+
final TermStatistics[] stats = lEntry.termStatistics();
224+
assert terms.length == stats.length;
225+
for (int i = 0; i < terms.length; i++) {
226+
assert terms[i] != null;
227+
if (stats[i] == null) {
228+
continue;
229+
}
230+
TermStatistics existing = termStatistics.get(terms[i]);
231+
if (existing != null) {
232+
assert terms[i].bytes().equals(existing.term());
233+
termStatistics.put(
234+
terms[i],
235+
new TermStatistics(
236+
existing.term(),
237+
existing.docFreq() + stats[i].docFreq(),
238+
existing.totalTermFreq() + stats[i].totalTermFreq()
239+
)
240+
);
241+
} else {
242+
termStatistics.put(terms[i], stats[i]);
243+
}
244+
245+
}
246+
247+
assert lEntry.fieldStatistics().containsKey(null) == false;
248+
for (var entry : lEntry.fieldStatistics().entrySet()) {
249+
String key = entry.getKey();
250+
CollectionStatistics value = entry.getValue();
251+
if (value == null) {
252+
continue;
253+
}
254+
assert key != null;
255+
CollectionStatistics existing = fieldStatistics.get(key);
256+
if (existing != null) {
257+
CollectionStatistics merged = new CollectionStatistics(
258+
key,
259+
existing.maxDoc() + value.maxDoc(),
260+
existing.docCount() + value.docCount(),
261+
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
262+
existing.sumDocFreq() + value.sumDocFreq()
263+
);
264+
fieldStatistics.put(key, merged);
265+
} else {
266+
fieldStatistics.put(key, value);
267+
}
268+
}
269+
aggMaxDoc += lEntry.maxDoc();
270+
}
271+
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
272+
}
183273
}

server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,47 @@
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.common.Strings;
1414
import org.elasticsearch.common.util.Maps;
15+
import org.elasticsearch.common.util.concurrent.AtomicArray;
1516
import org.elasticsearch.index.query.BoolQueryBuilder;
1617
import org.elasticsearch.index.query.InnerHitBuilder;
1718
import org.elasticsearch.index.query.QueryBuilder;
1819
import org.elasticsearch.index.query.QueryBuilders;
1920
import org.elasticsearch.search.SearchHit;
2021
import org.elasticsearch.search.SearchHits;
22+
import org.elasticsearch.search.SearchPhaseResult;
2123
import org.elasticsearch.search.builder.SearchSourceBuilder;
2224
import org.elasticsearch.search.collapse.CollapseBuilder;
2325

2426
import java.util.Iterator;
2527
import java.util.List;
26-
import java.util.function.Supplier;
2728

2829
/**
2930
* This search phase is an optional phase that will be executed once all hits are fetched from the shards; it executes
3031
* field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise
3132
* forwards to the next phase immediately.
3233
*/
33-
final class ExpandSearchPhase extends SearchPhase {
34+
class ExpandSearchPhase extends SearchPhase {
3435

3536
static final String NAME = "expand";
3637

3738
private final AbstractSearchAsyncAction<?> context;
38-
private final SearchHits searchHits;
39-
private final Supplier<SearchPhase> nextPhase;
39+
private final SearchResponseSections searchResponseSections;
40+
private final AtomicArray<SearchPhaseResult> queryPhaseResults;
4041

41-
ExpandSearchPhase(AbstractSearchAsyncAction<?> context, SearchHits searchHits, Supplier<SearchPhase> nextPhase) {
42+
ExpandSearchPhase(
43+
AbstractSearchAsyncAction<?> context,
44+
SearchResponseSections searchResponseSections,
45+
AtomicArray<SearchPhaseResult> queryPhaseResults
46+
) {
4247
super(NAME);
4348
this.context = context;
44-
this.searchHits = searchHits;
45-
this.nextPhase = nextPhase;
49+
this.searchResponseSections = searchResponseSections;
50+
this.queryPhaseResults = queryPhaseResults;
51+
}
52+
53+
// protected for tests
54+
protected SearchPhase nextPhase() {
55+
return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults);
4656
}
4757

4858
/**
@@ -55,14 +65,15 @@ private boolean isCollapseRequest() {
5565

5666
@Override
5767
protected void run() {
68+
var searchHits = searchResponseSections.hits();
5869
if (isCollapseRequest() == false || searchHits.getHits().length == 0) {
5970
onPhaseDone();
6071
} else {
61-
doRun();
72+
doRun(searchHits);
6273
}
6374
}
6475

65-
private void doRun() {
76+
private void doRun(SearchHits searchHits) {
6677
SearchRequest searchRequest = context.getRequest();
6778
CollapseBuilder collapseBuilder = searchRequest.source().collapse();
6879
final List<InnerHitBuilder> innerHitBuilders = collapseBuilder.getInnerHits();
@@ -171,6 +182,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde
171182
}
172183

173184
private void onPhaseDone() {
174-
context.executeNextPhase(NAME, nextPhase);
185+
context.executeNextPhase(NAME, this::nextPhase);
175186
}
176187
}

server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@ final class FetchLookupFieldsPhase extends SearchPhase {
5151
this.queryResults = queryResults;
5252
}
5353

54-
private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {
55-
56-
}
54+
private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {}
5755

5856
private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHits) {
5957
final Map<String, List<SearchHit>> perClusters = new HashMap<>();
@@ -80,7 +78,7 @@ private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHi
8078
protected void run() {
8179
final List<Cluster> clusters = groupLookupFieldsByClusterAlias(searchResponse.hits);
8280
if (clusters.isEmpty()) {
83-
context.sendSearchResponse(searchResponse, queryResults);
81+
sendResponse();
8482
return;
8583
}
8684
doRun(clusters);
@@ -132,9 +130,9 @@ public void onResponse(MultiSearchResponse items) {
132130
}
133131
}
134132
if (failure != null) {
135-
context.onPhaseFailure(NAME, "failed to fetch lookup fields", failure);
133+
onFailure(failure);
136134
} else {
137-
context.sendSearchResponse(searchResponse, queryResults);
135+
sendResponse();
138136
}
139137
}
140138

@@ -144,4 +142,8 @@ public void onFailure(Exception e) {
144142
}
145143
});
146144
}
145+
146+
private void sendResponse() {
147+
context.sendSearchResponse(searchResponse, queryResults);
148+
}
147149
}

0 commit comments

Comments
 (0)