@@ -91,15 +91,13 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
9191 private final Object shardFailuresMutex = new Object ();
9292 private final AtomicBoolean hasShardResponse = new AtomicBoolean (false );
9393 private final AtomicInteger successfulOps = new AtomicInteger ();
94- private final AtomicInteger skippedOps = new AtomicInteger ();
9594 private final SearchTimeProvider timeProvider ;
9695 private final SearchResponse .Clusters clusters ;
9796
9897 protected final GroupShardsIterator <SearchShardIterator > toSkipShardsIts ;
9998 protected final GroupShardsIterator <SearchShardIterator > shardsIts ;
10099 private final SearchShardIterator [] shardIterators ;
101- private final int expectedTotalOps ;
102- private final AtomicInteger totalOps = new AtomicInteger ();
100+ private final AtomicInteger outstandingShards ;
103101 private final int maxConcurrentRequestsPerNode ;
104102 private final Map <String , PendingExecutions > pendingExecutionsPerNode = new ConcurrentHashMap <>();
105103 private final boolean throttleConcurrentRequests ;
@@ -140,18 +138,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
140138 }
141139 this .toSkipShardsIts = new GroupShardsIterator <>(toSkipIterators );
142140 this .shardsIts = new GroupShardsIterator <>(iterators );
143-
141+ outstandingShards = new AtomicInteger ( shardsIts . size ());
144142 this .shardIterators = iterators .toArray (new SearchShardIterator [0 ]);
145143 // we later compute the shard index based on the natural order of the shards
146144 // that participate in the search request. This means that this number is
147145 // consistent between two requests that target the same shards.
148146 Arrays .sort (shardIterators );
149-
150- // we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
151- // it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
152- // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
153- // we process hence we add one for the non active partition here.
154- this .expectedTotalOps = shardsIts .totalSizeWith1ForEmpty ();
155147 this .maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode ;
156148 // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle
157149 this .throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts .size ();
@@ -261,9 +253,8 @@ public final void run() {
261253
262254 void skipShard (SearchShardIterator iterator ) {
263255 successfulOps .incrementAndGet ();
264- skippedOps .incrementAndGet ();
265256 assert iterator .skip ();
266- successfulShardExecution (iterator );
257+ successfulShardExecution ();
267258 }
268259
269260 private boolean checkMinimumVersion (GroupShardsIterator <SearchShardIterator > shardsIts ) {
@@ -405,7 +396,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
405396 "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})" ,
406397 discrepancy ,
407398 successfulOps .get (),
408- skippedOps . get (),
399+ toSkipShardsIts . size (),
409400 getNumShards (),
410401 currentPhase .getName ()
411402 );
@@ -474,17 +465,14 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final
474465 }
475466 onShardGroupFailure (shardIndex , shard , e );
476467 }
477- final int totalOps = this .totalOps .incrementAndGet ();
478- if (totalOps == expectedTotalOps ) {
479- onPhaseDone ();
480- } else if (totalOps > expectedTotalOps ) {
481- throw new AssertionError (
482- "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]" ,
483- new SearchPhaseExecutionException (getName (), "Shard failures" , null , buildShardFailures ())
484- );
468+ if (lastShard == false ) {
469+ performPhaseOnShard (shardIndex , shardIt , nextShard );
485470 } else {
486- if (lastShard == false ) {
487- performPhaseOnShard (shardIndex , shardIt , nextShard );
471+ // count down outstanding shards, we're done with this shard as there's no more copies to try
472+ final int outstanding = outstandingShards .decrementAndGet ();
473+ assert outstanding >= 0 : "outstanding: " + outstanding ;
474+ if (outstanding == 0 ) {
475+ onPhaseDone ();
488476 }
489477 }
490478 }
@@ -561,10 +549,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
561549 if (logger .isTraceEnabled ()) {
562550 logger .trace ("got first-phase result from {}" , result != null ? result .getSearchShardTarget () : null );
563551 }
564- results .consumeResult (result , () -> onShardResultConsumed (result , shardIt ));
552+ results .consumeResult (result , () -> onShardResultConsumed (result ));
565553 }
566554
567- private void onShardResultConsumed (Result result , SearchShardIterator shardIt ) {
555+ private void onShardResultConsumed (Result result ) {
568556 successfulOps .incrementAndGet ();
569557 // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
570558 // so its ok concurrency wise to miss potentially the shard failures being created because of another failure
@@ -578,28 +566,14 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
578566 // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc.
579567 // increment all the "future" shards to update the total ops since we some may work and some may not...
580568 // and when that happens, we break on total ops, so we must maintain them
581- successfulShardExecution (shardIt );
569+ successfulShardExecution ();
582570 }
583571
584- private void successfulShardExecution (SearchShardIterator shardsIt ) {
585- final int remainingOpsOnIterator ;
586- if (shardsIt .skip ()) {
587- // It's possible that we're skipping a shard that's unavailable
588- // but its range was available in the IndexMetadata, in that
589- // case the shardsIt.remaining() would be 0, expectedTotalOps
590- // accounts for unavailable shards too.
591- remainingOpsOnIterator = Math .max (shardsIt .remaining (), 1 );
592- } else {
593- remainingOpsOnIterator = shardsIt .remaining () + 1 ;
594- }
595- final int xTotalOps = totalOps .addAndGet (remainingOpsOnIterator );
596- if (xTotalOps == expectedTotalOps ) {
572+ private void successfulShardExecution () {
573+ final int outstanding = outstandingShards .decrementAndGet ();
574+ assert outstanding >= 0 : "outstanding: " + outstanding ;
575+ if (outstanding == 0 ) {
597576 onPhaseDone ();
598- } else if (xTotalOps > expectedTotalOps ) {
599- throw new AssertionError (
600- "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]" ,
601- new SearchPhaseExecutionException (getName (), "Shard failures" , null , buildShardFailures ())
602- );
603577 }
604578 }
605579
@@ -666,7 +640,7 @@ private SearchResponse buildSearchResponse(
666640 scrollId ,
667641 getNumShards (),
668642 numSuccess ,
669- skippedOps . get (),
643+ toSkipShardsIts . size (),
670644 buildTookInMillis (),
671645 failures ,
672646 clusters ,