3636import java .util .concurrent .Executor ;
3737import java .util .concurrent .atomic .AtomicReference ;
3838import java .util .function .Consumer ;
39+ import java .util .stream .Collectors ;
3940
4041import static org .elasticsearch .action .search .SearchPhaseController .getTopDocsSize ;
4142import static org .elasticsearch .action .search .SearchPhaseController .mergeTopDocs ;
@@ -71,14 +72,16 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
7172 * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results
7273 * as shard results are consumed.
7374 */
74- public QueryPhaseResultConsumer (SearchRequest request ,
75- Executor executor ,
76- CircuitBreaker circuitBreaker ,
77- SearchPhaseController controller ,
78- SearchProgressListener progressListener ,
79- NamedWriteableRegistry namedWriteableRegistry ,
80- int expectedResultSize ,
81- Consumer <Exception > onPartialMergeFailure ) {
75+ public QueryPhaseResultConsumer (
76+ SearchRequest request ,
77+ Executor executor ,
78+ CircuitBreaker circuitBreaker ,
79+ SearchPhaseController controller ,
80+ SearchProgressListener progressListener ,
81+ NamedWriteableRegistry namedWriteableRegistry ,
82+ int expectedResultSize ,
83+ Consumer <Exception > onPartialMergeFailure
84+ ) {
8285 super (expectedResultSize );
8386 this .executor = executor ;
8487 this .circuitBreaker = circuitBreaker ;
@@ -93,7 +96,7 @@ public QueryPhaseResultConsumer(SearchRequest request,
9396 SearchSourceBuilder source = request .source ();
9497 this .hasTopDocs = source == null || source .size () != 0 ;
9598 this .hasAggs = source != null && source .aggregations () != null ;
96- int batchReduceSize = (hasAggs || hasTopDocs ) ? Math .min (request .getBatchedReduceSize (), expectedResultSize ) : expectedResultSize ;
99+ int batchReduceSize = (hasAggs || hasTopDocs ) ? Math .min (request .getBatchedReduceSize (), expectedResultSize ) : expectedResultSize ;
97100 this .pendingMerges = new PendingMerges (batchReduceSize , request .resolveTrackTotalHitsUpTo ());
98101 }
99102
@@ -128,28 +131,41 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
128131 // Add an estimate of the final reduce size
129132 breakerSize = pendingMerges .addEstimateAndMaybeBreak (pendingMerges .estimateRamBytesUsedForReduce (breakerSize ));
130133 }
131- SearchPhaseController .ReducedQueryPhase reducePhase = controller .reducedQueryPhase (results .asList (), aggsList ,
132- topDocsList , topDocsStats , pendingMerges .numReducePhases , false , aggReduceContextBuilder , performFinalReduce );
134+ SearchPhaseController .ReducedQueryPhase reducePhase = controller .reducedQueryPhase (
135+ results .asList (),
136+ aggsList ,
137+ topDocsList ,
138+ topDocsStats ,
139+ pendingMerges .numReducePhases ,
140+ false ,
141+ aggReduceContextBuilder ,
142+ performFinalReduce
143+ );
133144 if (hasAggs
134- // reduced aggregations can be null if all shards failed
135- && reducePhase .aggregations != null ) {
145+ // reduced aggregations can be null if all shards failed
146+ && reducePhase .aggregations != null ) {
136147
137148 // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result
138149 long finalSize = DelayableWriteable .getSerializedSize (reducePhase .aggregations ) - breakerSize ;
139150 pendingMerges .addWithoutBreaking (finalSize );
140- logger .trace ("aggs final reduction [{}] max [{}]" ,
141- pendingMerges .aggsCurrentBufferSize , pendingMerges .maxAggsCurrentBufferSize );
151+ logger .trace ("aggs final reduction [{}] max [{}]" , pendingMerges .aggsCurrentBufferSize , pendingMerges .maxAggsCurrentBufferSize );
142152 }
143- progressListener .notifyFinalReduce (SearchProgressListener .buildSearchShards (results .asList ()),
144- reducePhase .totalHits , reducePhase .aggregations , reducePhase .numReducePhases );
153+ progressListener .notifyFinalReduce (
154+ SearchProgressListener .buildSearchShards (results .asList ()),
155+ reducePhase .totalHits ,
156+ reducePhase .aggregations ,
157+ reducePhase .numReducePhases
158+ );
145159 return reducePhase ;
146160 }
147161
148- private MergeResult partialReduce (QuerySearchResult [] toConsume ,
149- List <SearchShard > emptyResults ,
150- TopDocsStats topDocsStats ,
151- MergeResult lastMerge ,
152- int numReducePhases ) {
162+ private MergeResult partialReduce (
163+ QuerySearchResult [] toConsume ,
164+ List <SearchShard > emptyResults ,
165+ TopDocsStats topDocsStats ,
166+ MergeResult lastMerge ,
167+ int numReducePhases
168+ ) {
153169 // ensure consistent ordering
154170 Arrays .sort (toConsume , Comparator .comparingInt (QuerySearchResult ::getShardIndex ));
155171
@@ -168,9 +184,12 @@ private MergeResult partialReduce(QuerySearchResult[] toConsume,
168184 setShardIndex (topDocs .topDocs , result .getShardIndex ());
169185 topDocsList .add (topDocs .topDocs );
170186 }
171- newTopDocs = mergeTopDocs (topDocsList ,
187+ newTopDocs = mergeTopDocs (
188+ topDocsList ,
172189 // we have to merge here in the same way we collect on a shard
173- topNSize , 0 );
190+ topNSize ,
191+ 0
192+ );
174193 } else {
175194 newTopDocs = null ;
176195 }
@@ -233,14 +252,24 @@ private class PendingMerges implements Releasable {
233252
234253 @ Override
235254 public synchronized void close () {
236- assert hasPendingMerges () == false : "cannot close with partial reduce in-flight" ;
237255 if (hasFailure ()) {
238256 assert circuitBreakerBytes == 0 ;
239- return ;
257+ } else {
258+ assert circuitBreakerBytes >= 0 ;
259+ }
260+
261+ List <Releasable > toRelease = new ArrayList <>(buffer .stream ().<Releasable >map (b -> b ::releaseAggs ).collect (Collectors .toList ()));
262+ toRelease .add (() -> {
263+ circuitBreaker .addWithoutBreaking (-circuitBreakerBytes );
264+ circuitBreakerBytes = 0 ;
265+ });
266+
267+ Releasables .close (toRelease );
268+
269+ if (hasPendingMerges ()) {
270+ // This is a theoretically unreachable exception.
271+ throw new IllegalStateException ("Attempted to close with partial reduce in-flight" );
240272 }
241- assert circuitBreakerBytes >= 0 ;
242- circuitBreaker .addWithoutBreaking (-circuitBreakerBytes );
243- circuitBreakerBytes = 0 ;
244273 }
245274
246275 synchronized Exception getFailure () {
@@ -378,8 +407,12 @@ private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedS
378407 // and replace the estimation with the serialized size of the newly reduced result.
379408 long newSize = mergeResult .estimatedSize - estimatedSize ;
380409 addWithoutBreaking (newSize );
381- logger .trace ("aggs partial reduction [{}->{}] max [{}]" ,
382- estimatedSize , mergeResult .estimatedSize , maxAggsCurrentBufferSize );
410+ logger .trace (
411+ "aggs partial reduction [{}->{}] max [{}]" ,
412+ estimatedSize ,
413+ mergeResult .estimatedSize ,
414+ maxAggsCurrentBufferSize
415+ );
383416 }
384417 task .consumeListener ();
385418 }
@@ -388,9 +421,7 @@ private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedS
388421 private void tryExecuteNext () {
389422 final MergeTask task ;
390423 synchronized (this ) {
391- if (queue .isEmpty ()
392- || hasFailure ()
393- || runningTask .get () != null ) {
424+ if (queue .isEmpty () || hasFailure () || runningTask .get () != null ) {
394425 return ;
395426 }
396427 task = queue .poll ();
@@ -411,7 +442,7 @@ protected void doRun() {
411442 long estimatedMergeSize = estimateRamBytesUsedForReduce (estimatedTotalSize );
412443 addEstimateAndMaybeBreak (estimatedMergeSize );
413444 estimatedTotalSize += estimatedMergeSize ;
414- ++ numReducePhases ;
445+ ++numReducePhases ;
415446 newMerge = partialReduce (toConsume , task .emptyResults , topDocsStats , thisMergeResult , numReducePhases );
416447 } catch (Exception t ) {
417448 for (QuerySearchResult result : toConsume ) {
@@ -475,8 +506,12 @@ private static class MergeResult {
475506 private final InternalAggregations reducedAggs ;
476507 private final long estimatedSize ;
477508
478- private MergeResult (List <SearchShard > processedShards , TopDocs reducedTopDocs ,
479- InternalAggregations reducedAggs , long estimatedSize ) {
509+ private MergeResult (
510+ List <SearchShard > processedShards ,
511+ TopDocs reducedTopDocs ,
512+ InternalAggregations reducedAggs ,
513+ long estimatedSize
514+ ) {
480515 this .processedShards = processedShards ;
481516 this .reducedTopDocs = reducedTopDocs ;
482517 this .reducedAggs = reducedAggs ;
0 commit comments