Skip to content

Commit 01920ba

Browse files
save more
1 parent d2da66d commit 01920ba

File tree

3 files changed

+22
-35
lines changed

3 files changed

+22
-35
lines changed

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,55 +40,45 @@
4040
*/
4141
class DfsQueryPhase extends SearchPhase {
4242
private final SearchPhaseResults<SearchPhaseResult> queryResult;
43-
private final List<DfsSearchResult> searchResults;
44-
private final AggregatedDfs dfs;
45-
private final List<DfsKnnResults> knnResults;
4643
private final Client client;
4744
private final AbstractSearchAsyncAction<?> context;
48-
private final SearchTransportService searchTransportService;
4945
private final SearchProgressListener progressListener;
5046

51-
DfsQueryPhase(
52-
List<DfsSearchResult> searchResults,
53-
List<DfsKnnResults> knnResults,
54-
SearchPhaseResults<SearchPhaseResult> queryResult,
55-
Client client,
56-
AbstractSearchAsyncAction<?> context
57-
) {
47+
DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
5848
super("dfs_query");
5949
this.progressListener = context.getTask().getProgressListener();
6050
this.queryResult = queryResult;
61-
this.searchResults = searchResults;
62-
this.dfs = SearchPhaseController.aggregateDfs(searchResults);
63-
this.knnResults = knnResults;
6451
this.client = client;
6552
this.context = context;
66-
this.searchTransportService = context.getSearchTransport();
6753
}
6854

6955
// protected for testing
70-
protected SearchPhase nextPhase() {
56+
protected SearchPhase nextPhase(AggregatedDfs dfs) {
7157
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
7258
}
7359

60+
@SuppressWarnings("unchecked")
7461
@Override
7562
public void run() {
63+
List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
64+
AggregatedDfs dfs = SearchPhaseController.aggregateDfs(searchResults);
7665
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
7766
// to free up memory early
7867
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
7968
queryResult,
8069
searchResults.size(),
81-
() -> context.executeNextPhase(this, this::nextPhase),
70+
() -> context.executeNextPhase(this, () -> nextPhase(dfs)),
8271
context
8372
);
8473

74+
List<DfsKnnResults> knnResults = SearchPhaseController.mergeKnnResults(context.getRequest(), searchResults);
8575
for (final DfsSearchResult dfsResult : searchResults) {
8676
final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
8777
final int shardIndex = dfsResult.getShardIndex();
8878
QuerySearchRequest querySearchRequest = new QuerySearchRequest(
8979
context.getOriginalIndices(shardIndex),
9080
dfsResult.getContextId(),
91-
rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
81+
rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
9282
dfs
9383
);
9484
final Transport.Connection connection;
@@ -98,11 +88,8 @@ public void run() {
9888
shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
9989
continue;
10090
}
101-
searchTransportService.sendExecuteQuery(
102-
connection,
103-
querySearchRequest,
104-
context.getTask(),
105-
new SearchActionListener<>(shardTarget, shardIndex) {
91+
context.getSearchTransport()
92+
.sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {
10693

10794
@Override
10895
protected void innerOnResponse(QuerySearchResult response) {
@@ -127,8 +114,7 @@ public void onFailure(Exception exception) {
127114
}
128115
}
129116
}
130-
}
131-
);
117+
});
132118
}
133119
}
134120

@@ -145,7 +131,7 @@ private void shardFailure(
145131
}
146132

147133
// package private for testing
148-
ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
134+
ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
149135
SearchSourceBuilder source = request.source();
150136
if (source == null || source.knnSearch().isEmpty()) {
151137
return request;

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1818
import org.elasticsearch.search.SearchPhaseResult;
1919
import org.elasticsearch.search.SearchShardTarget;
20-
import org.elasticsearch.search.dfs.DfsKnnResults;
2120
import org.elasticsearch.search.dfs.DfsSearchResult;
2221
import org.elasticsearch.search.internal.AliasFilter;
2322
import org.elasticsearch.transport.Transport;
2423

25-
import java.util.List;
2624
import java.util.Map;
2725
import java.util.concurrent.Executor;
2826
import java.util.function.BiFunction;
@@ -91,9 +89,7 @@ protected void executePhaseOnShard(
9189

9290
@Override
9391
protected SearchPhase getNextPhase() {
94-
final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
95-
final List<DfsKnnResults> mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);
96-
return new DfsQueryPhase(dfsSearchResults, mergedKnnResults, queryPhaseResultConsumer, client, this);
92+
return new DfsQueryPhase(queryPhaseResultConsumer, client, this);
9793
}
9894

9995
@Override

server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.search.SearchPhaseResult;
2828
import org.elasticsearch.search.SearchShardTarget;
2929
import org.elasticsearch.search.builder.SearchSourceBuilder;
30+
import org.elasticsearch.search.dfs.AggregatedDfs;
3031
import org.elasticsearch.search.dfs.DfsKnnResults;
3132
import org.elasticsearch.search.dfs.DfsSearchResult;
3233
import org.elasticsearch.search.internal.AliasFilter;
@@ -319,9 +320,13 @@ private static DfsQueryPhase makeDfsPhase(
319320
MockSearchPhaseContext mockSearchPhaseContext,
320321
AtomicReference<AtomicArray<SearchPhaseResult>> responseRef
321322
) {
322-
return new DfsQueryPhase(results.asList(), null, consumer, null, mockSearchPhaseContext) {
323+
int shards = mockSearchPhaseContext.numShards;
324+
for (int i = 0; i < shards; i++) {
325+
mockSearchPhaseContext.results.getAtomicArray().set(i, results.get(i));
326+
}
327+
return new DfsQueryPhase(consumer, null, mockSearchPhaseContext) {
323328
@Override
324-
protected SearchPhase nextPhase() {
329+
protected SearchPhase nextPhase(AggregatedDfs dfs) {
325330
return new SearchPhase("test") {
326331
@Override
327332
public void run() {
@@ -342,7 +347,7 @@ public void testRewriteShardSearchRequestWithRank() {
342347
);
343348
MockSearchPhaseContext mspc = new MockSearchPhaseContext(2);
344349
mspc.searchTransport = new SearchTransportService(null, null, null);
345-
DfsQueryPhase dqp = new DfsQueryPhase(List.of(), dkrs, mock(QueryPhaseResultConsumer.class), null, mspc);
350+
DfsQueryPhase dqp = new DfsQueryPhase(mock(QueryPhaseResultConsumer.class), null, mspc);
346351

347352
QueryBuilder bm25 = new TermQueryBuilder("field", "term");
348353
SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25)
@@ -356,7 +361,7 @@ public void testRewriteShardSearchRequestWithRank() {
356361
SearchRequest sr = new SearchRequest().allowPartialSearchResults(true).source(ssb);
357362
ShardSearchRequest ssr = new ShardSearchRequest(null, sr, new ShardId("test", "testuuid", 1), 1, 1, null, 1.0f, 0, null);
358363

359-
dqp.rewriteShardSearchRequest(ssr);
364+
dqp.rewriteShardSearchRequest(dkrs, ssr);
360365

361366
KnnScoreDocQueryBuilder ksdqb0 = new KnnScoreDocQueryBuilder(
362367
new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) },

0 commit comments

Comments
 (0)