4040 */
4141class DfsQueryPhase extends SearchPhase {
4242 private final SearchPhaseResults <SearchPhaseResult > queryResult ;
43- private final List <DfsSearchResult > searchResults ;
44- private final AggregatedDfs dfs ;
45- private final List <DfsKnnResults > knnResults ;
4643 private final Client client ;
4744 private final AbstractSearchAsyncAction <?> context ;
48- private final SearchTransportService searchTransportService ;
4945 private final SearchProgressListener progressListener ;
5046
51- DfsQueryPhase (
52- List <DfsSearchResult > searchResults ,
53- List <DfsKnnResults > knnResults ,
54- SearchPhaseResults <SearchPhaseResult > queryResult ,
55- Client client ,
56- AbstractSearchAsyncAction <?> context
57- ) {
47+ DfsQueryPhase (SearchPhaseResults <SearchPhaseResult > queryResult , Client client , AbstractSearchAsyncAction <?> context ) {
5848 super ("dfs_query" );
5949 this .progressListener = context .getTask ().getProgressListener ();
6050 this .queryResult = queryResult ;
61- this .searchResults = searchResults ;
62- this .dfs = SearchPhaseController .aggregateDfs (searchResults );
63- this .knnResults = knnResults ;
6451 this .client = client ;
6552 this .context = context ;
66- this .searchTransportService = context .getSearchTransport ();
6753 }
6854
6955 // protected for testing
70- protected SearchPhase nextPhase () {
56+ protected SearchPhase nextPhase (AggregatedDfs dfs ) {
7157 return SearchQueryThenFetchAsyncAction .nextPhase (client , context , queryResult , dfs );
7258 }
7359
60+ @ SuppressWarnings ("unchecked" )
7461 @ Override
7562 public void run () {
63+ List <DfsSearchResult > searchResults = (List <DfsSearchResult >) context .results .getAtomicArray ().asList ();
64+ AggregatedDfs dfs = SearchPhaseController .aggregateDfs (searchResults );
7665 // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
7766 // to free up memory early
7867 final CountedCollector <SearchPhaseResult > counter = new CountedCollector <>(
7968 queryResult ,
8069 searchResults .size (),
81- () -> context .executeNextPhase (this , this :: nextPhase ),
70+ () -> context .executeNextPhase (this , () -> nextPhase ( dfs ) ),
8271 context
8372 );
8473
74+ List <DfsKnnResults > knnResults = SearchPhaseController .mergeKnnResults (context .getRequest (), searchResults );
8575 for (final DfsSearchResult dfsResult : searchResults ) {
8676 final SearchShardTarget shardTarget = dfsResult .getSearchShardTarget ();
8777 final int shardIndex = dfsResult .getShardIndex ();
8878 QuerySearchRequest querySearchRequest = new QuerySearchRequest (
8979 context .getOriginalIndices (shardIndex ),
9080 dfsResult .getContextId (),
91- rewriteShardSearchRequest (dfsResult .getShardSearchRequest ()),
81+ rewriteShardSearchRequest (knnResults , dfsResult .getShardSearchRequest ()),
9282 dfs
9383 );
9484 final Transport .Connection connection ;
@@ -98,11 +88,8 @@ public void run() {
9888 shardFailure (e , querySearchRequest , shardIndex , shardTarget , counter );
9989 continue ;
10090 }
101- searchTransportService .sendExecuteQuery (
102- connection ,
103- querySearchRequest ,
104- context .getTask (),
105- new SearchActionListener <>(shardTarget , shardIndex ) {
91+ context .getSearchTransport ()
92+ .sendExecuteQuery (connection , querySearchRequest , context .getTask (), new SearchActionListener <>(shardTarget , shardIndex ) {
10693
10794 @ Override
10895 protected void innerOnResponse (QuerySearchResult response ) {
@@ -127,8 +114,7 @@ public void onFailure(Exception exception) {
127114 }
128115 }
129116 }
130- }
131- );
117+ });
132118 }
133119 }
134120
@@ -145,7 +131,7 @@ private void shardFailure(
145131 }
146132
147133 // package private for testing
148- ShardSearchRequest rewriteShardSearchRequest (ShardSearchRequest request ) {
134+ ShardSearchRequest rewriteShardSearchRequest (List < DfsKnnResults > knnResults , ShardSearchRequest request ) {
149135 SearchSourceBuilder source = request .source ();
150136 if (source == null || source .knnSearch ().isEmpty ()) {
151137 return request ;
0 commit comments