@@ -90,15 +90,13 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
9090 private final Object shardFailuresMutex = new Object ();
9191 private final AtomicBoolean hasShardResponse = new AtomicBoolean (false );
9292 private final AtomicInteger successfulOps = new AtomicInteger ();
93- private final AtomicInteger skippedOps = new AtomicInteger ();
9493 private final SearchTimeProvider timeProvider ;
9594 private final SearchResponse .Clusters clusters ;
9695
9796 protected final GroupShardsIterator <SearchShardIterator > toSkipShardsIts ;
9897 protected final GroupShardsIterator <SearchShardIterator > shardsIts ;
9998 private final SearchShardIterator [] shardIterators ;
100- private final int expectedTotalOps ;
101- private final AtomicInteger totalOps = new AtomicInteger ();
99+ private final AtomicInteger outstandingShards ;
102100 private final int maxConcurrentRequestsPerNode ;
103101 private final Map <String , PendingExecutions > pendingExecutionsPerNode = new ConcurrentHashMap <>();
104102 private final boolean throttleConcurrentRequests ;
@@ -139,18 +137,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
139137 }
140138 this .toSkipShardsIts = new GroupShardsIterator <>(toSkipIterators );
141139 this .shardsIts = new GroupShardsIterator <>(iterators );
142-
140+ outstandingShards = new AtomicInteger ( shardsIts . size ());
143141 this .shardIterators = iterators .toArray (new SearchShardIterator [0 ]);
144142 // we later compute the shard index based on the natural order of the shards
145143 // that participate in the search request. This means that this number is
146144 // consistent between two requests that target the same shards.
147145 Arrays .sort (shardIterators );
148-
149- // we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
150- // it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
151- // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
152- // we process hence we add one for the non active partition here.
153- this .expectedTotalOps = shardsIts .totalSizeWith1ForEmpty ();
154146 this .maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode ;
155147 // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle
156148 this .throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts .size ();
@@ -251,9 +243,8 @@ protected final void run() {
251243
252244 void skipShard (SearchShardIterator iterator ) {
253245 successfulOps .incrementAndGet ();
254- skippedOps .incrementAndGet ();
255246 assert iterator .skip ();
256- successfulShardExecution (iterator );
247+ successfulShardExecution ();
257248 }
258249
259250 private static boolean assertExecuteOnStartThread () {
@@ -380,7 +371,7 @@ protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextP
380371 "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})" ,
381372 discrepancy ,
382373 successfulOps .get (),
383- skippedOps . get (),
374+ toSkipShardsIts . size (),
384375 getNumShards (),
385376 currentPhase
386377 );
@@ -449,17 +440,14 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final
449440 }
450441 onShardGroupFailure (shardIndex , shard , e );
451442 }
452- final int totalOps = this .totalOps .incrementAndGet ();
453- if (totalOps == expectedTotalOps ) {
454- onPhaseDone ();
455- } else if (totalOps > expectedTotalOps ) {
456- throw new AssertionError (
457- "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]" ,
458- new SearchPhaseExecutionException (getName (), "Shard failures" , null , buildShardFailures ())
459- );
443+ if (lastShard == false ) {
444+ performPhaseOnShard (shardIndex , shardIt , nextShard );
460445 } else {
461- if (lastShard == false ) {
462- performPhaseOnShard (shardIndex , shardIt , nextShard );
446+ // count down outstanding shards, we're done with this shard as there's no more copies to try
447+ final int outstanding = outstandingShards .decrementAndGet ();
448+ assert outstanding >= 0 : "outstanding: " + outstanding ;
449+ if (outstanding == 0 ) {
450+ onPhaseDone ();
463451 }
464452 }
465453 }
@@ -535,10 +523,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
535523 if (logger .isTraceEnabled ()) {
536524 logger .trace ("got first-phase result from {}" , result != null ? result .getSearchShardTarget () : null );
537525 }
538- results .consumeResult (result , () -> onShardResultConsumed (result , shardIt ));
526+ results .consumeResult (result , () -> onShardResultConsumed (result ));
539527 }
540528
541- private void onShardResultConsumed (Result result , SearchShardIterator shardIt ) {
529+ private void onShardResultConsumed (Result result ) {
542530 successfulOps .incrementAndGet ();
543531 // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
544532 // so its ok concurrency wise to miss potentially the shard failures being created because of another failure
@@ -552,28 +540,14 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
552540 // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc.
553541 // increment all the "future" shards to update the total ops since we some may work and some may not...
554542 // and when that happens, we break on total ops, so we must maintain them
555- successfulShardExecution (shardIt );
543+ successfulShardExecution ();
556544 }
557545
558- private void successfulShardExecution (SearchShardIterator shardsIt ) {
559- final int remainingOpsOnIterator ;
560- if (shardsIt .skip ()) {
561- // It's possible that we're skipping a shard that's unavailable
562- // but its range was available in the IndexMetadata, in that
563- // case the shardsIt.remaining() would be 0, expectedTotalOps
564- // accounts for unavailable shards too.
565- remainingOpsOnIterator = Math .max (shardsIt .remaining (), 1 );
566- } else {
567- remainingOpsOnIterator = shardsIt .remaining () + 1 ;
568- }
569- final int xTotalOps = totalOps .addAndGet (remainingOpsOnIterator );
570- if (xTotalOps == expectedTotalOps ) {
546+ private void successfulShardExecution () {
547+ final int outstanding = outstandingShards .decrementAndGet ();
548+ assert outstanding >= 0 : "outstanding: " + outstanding ;
549+ if (outstanding == 0 ) {
571550 onPhaseDone ();
572- } else if (xTotalOps > expectedTotalOps ) {
573- throw new AssertionError (
574- "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]" ,
575- new SearchPhaseExecutionException (getName (), "Shard failures" , null , buildShardFailures ())
576- );
577551 }
578552 }
579553
@@ -640,7 +614,7 @@ private SearchResponse buildSearchResponse(
640614 scrollId ,
641615 getNumShards (),
642616 numSuccess ,
643- skippedOps . get (),
617+ toSkipShardsIts . size (),
644618 buildTookInMillis (),
645619 failures ,
646620 clusters ,
0 commit comments