 import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.dfs.AggregatedDfs;
@@ -39,13 +40,15 @@ final class FetchSearchPhase extends SearchPhase {
     private final Logger logger;
     private final SearchProgressListener progressListener;
     private final AggregatedDfs aggregatedDfs;
+    @Nullable
+    private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
     private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase;
 
     FetchSearchPhase(
         SearchPhaseResults<SearchPhaseResult> resultConsumer,
         AggregatedDfs aggregatedDfs,
         SearchPhaseContext context,
-        SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+        @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
     ) {
         this(
             resultConsumer,
@@ -64,7 +67,7 @@ final class FetchSearchPhase extends SearchPhase {
         SearchPhaseResults<SearchPhaseResult> resultConsumer,
         AggregatedDfs aggregatedDfs,
         SearchPhaseContext context,
-        SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
+        @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
         BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
     ) {
         super("fetch");
@@ -85,14 +88,15 @@ final class FetchSearchPhase extends SearchPhase {
         this.logger = context.getLogger();
         this.progressListener = context.getTask().getProgressListener();
         this.reducedQueryPhase = reducedQueryPhase;
+        this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null;
     }
 
     @Override
     public void run() {
         context.execute(new AbstractRunnable() {
 
             @Override
-            protected void doRun() {
+            protected void doRun() throws Exception {
                 innerRun();
             }
 
@@ -103,7 +107,10 @@ public void onFailure(Exception e) {
         });
     }
 
-    private void innerRun() {
+    private void innerRun() throws Exception {
+        assert this.reducedQueryPhase == null ^ this.resultConsumer == null;
+        // depending on whether we executed the RankFeaturePhase we may or may not have the reduced query result computed already
+        final var reducedQueryPhase = this.reducedQueryPhase == null ? resultConsumer.reduce() : this.reducedQueryPhase;
         final int numShards = context.getNumShards();
         // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
         // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
@@ -113,15 +120,15 @@ private void innerRun() {
         if (queryAndFetchOptimization) {
             assert assertConsistentWithQueryAndFetchOptimization();
             // query AND fetch optimization
-            moveToNextPhase(searchPhaseShardResults);
+            moveToNextPhase(searchPhaseShardResults, reducedQueryPhase);
         } else {
             ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
             // no docs to fetch -- sidestep everything and return
             if (scoreDocs.length == 0) {
                 // we have to release contexts here to free up resources
                 searchPhaseShardResults.asList()
                     .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
-                moveToNextPhase(fetchResults.getAtomicArray());
+                moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase);
             } else {
                 final boolean shouldExplainRank = shouldExplainRankScores(context.getRequest());
                 final List<Map<Integer, RankDoc>> rankDocsPerShard = false == shouldExplainRank
@@ -134,7 +141,7 @@ private void innerRun() {
                 final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(
                     fetchResults,
                     docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
-                    () -> moveToNextPhase(fetchResults.getAtomicArray()),
+                    () -> moveToNextPhase(fetchResults.getAtomicArray(), reducedQueryPhase),
                     context
                 );
                 for (int i = 0; i < docIdsToLoad.length; i++) {
@@ -243,7 +250,10 @@ public void onFailure(Exception e) {
         );
     }
 
-    private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
+    private void moveToNextPhase(
+        AtomicArray<? extends SearchPhaseResult> fetchResultsArr,
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+    ) {
         var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
         context.addReleasable(resp::decRef);
         fetchResults.close();
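
The hunks above all serve one change: FetchSearchPhase now accepts a nullable ReducedQueryPhase and, when it is absent (i.e. the RankFeaturePhase did not already reduce the query results), retains the result consumer and reduces lazily at the start of innerRun(). Below is a minimal, self-contained sketch of that "precomputed XOR deferred" pattern. It is not Elasticsearch code; the names LazyReducePhase, ReducedResult, and ResultConsumer are hypothetical stand-ins.

import java.util.List;

// Holds EITHER a precomputed reduced result OR a consumer that can produce one,
// never both -- the same invariant the diff asserts with a null-check XOR.
final class LazyReducePhase {

    record ReducedResult(List<String> topDocs) {}

    interface ResultConsumer {
        ReducedResult reduce() throws Exception; // reduction may fail, which is why doRun() gained "throws Exception"
    }

    private final ReducedResult precomputed; // non-null when an earlier phase already reduced
    private final ResultConsumer consumer;   // non-null when reduction is deferred to this phase

    LazyReducePhase(ResultConsumer consumer, ReducedResult precomputed) {
        this.precomputed = precomputed;
        // Mirrors `this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null`:
        // keep the consumer only when there is nothing precomputed.
        this.consumer = precomputed == null ? consumer : null;
    }

    void run() throws Exception {
        // Exactly one source of truth must be present.
        assert precomputed == null ^ consumer == null;
        ReducedResult reduced = precomputed == null ? consumer.reduce() : precomputed;
        System.out.println("fetching docs for " + reduced.topDocs());
    }

    public static void main(String[] args) throws Exception {
        // Deferred: reduce lazily inside run(), as when no rank-feature phase ran.
        new LazyReducePhase(() -> new ReducedResult(List.of("doc-1", "doc-2")), null).run();
        // Precomputed: an earlier phase already produced the reduced result.
        new LazyReducePhase(null, new ReducedResult(List.of("doc-3"))).run();
    }
}

Nulling out the unused field in the constructor, rather than keeping both references alive, is what lets the XOR assertion catch any call path that supplies both or neither; it also explains why the reduced result is now threaded through moveToNextPhase as a parameter instead of being read from a field that may be null.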