 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.Transport;
 
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -90,15 +92,25 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
     private final Object shardFailuresMutex = new Object();
     private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
     private final AtomicInteger successfulOps = new AtomicInteger();
-    private final AtomicInteger skippedOps = new AtomicInteger();
     private final SearchTimeProvider timeProvider;
     private final SearchResponse.Clusters clusters;
 
     protected final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
     protected final GroupShardsIterator<SearchShardIterator> shardsIts;
     private final SearchShardIterator[] shardIterators;
-    private final int expectedTotalOps;
-    private final AtomicInteger totalOps = new AtomicInteger();
+
+    private static final VarHandle OUTSTANDING_SHARDS;
+
+    static {
+        try {
+            OUTSTANDING_SHARDS = MethodHandles.lookup().findVarHandle(AbstractSearchAsyncAction.class, "outstandingShards", int.class);
+        } catch (Exception e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    @SuppressWarnings("unused") // only accessed via #OUTSTANDING_SHARDS
+    private int outstandingShards;
     private final int maxConcurrentRequestsPerNode;
     private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
     private final boolean throttleConcurrentRequests;
@@ -139,18 +151,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
         }
         this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators);
         this.shardsIts = new GroupShardsIterator<>(iterators);
-
+        OUTSTANDING_SHARDS.setRelease(this, shardsIts.size());
         this.shardIterators = iterators.toArray(new SearchShardIterator[0]);
         // we later compute the shard index based on the natural order of the shards
         // that participate in the search request. This means that this number is
         // consistent between two requests that target the same shards.
         Arrays.sort(shardIterators);
-
-        // we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
-        // it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
-        // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
-        // we process hence we add one for the non active partition here.
-        this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty();
         this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode;
         // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle
         this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size();
@@ -251,9 +257,8 @@ protected final void run() {
 
     void skipShard(SearchShardIterator iterator) {
         successfulOps.incrementAndGet();
-        skippedOps.incrementAndGet();
         assert iterator.skip();
-        successfulShardExecution(iterator);
+        successfulShardExecution();
     }
 
     private static boolean assertExecuteOnStartThread() {
@@ -380,7 +385,7 @@ protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextP
                 "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})",
                 discrepancy,
                 successfulOps.get(),
-                skippedOps.get(),
+                toSkipShardsIts.size(),
                 getNumShards(),
                 currentPhase
             );
@@ -449,17 +454,11 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final
             }
             onShardGroupFailure(shardIndex, shard, e);
         }
-        final int totalOps = this.totalOps.incrementAndGet();
-        if (totalOps == expectedTotalOps) {
-            onPhaseDone();
-        } else if (totalOps > expectedTotalOps) {
-            throw new AssertionError(
-                "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
-                new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
-            );
+        if (lastShard == false) {
+            performPhaseOnShard(shardIndex, shardIt, nextShard);
         } else {
-            if (lastShard == false) {
-                performPhaseOnShard(shardIndex, shardIt, nextShard);
+            if ((int) OUTSTANDING_SHARDS.getAndAdd(this, -1) == 1) {
+                onPhaseDone();
             }
         }
     }
@@ -535,10 +534,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
         if (logger.isTraceEnabled()) {
             logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
         }
-        results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
+        results.consumeResult(result, () -> onShardResultConsumed(result));
     }
 
-    private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
+    private void onShardResultConsumed(Result result) {
         successfulOps.incrementAndGet();
         // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
         // so its ok concurrency wise to miss potentially the shard failures being created because of another failure
@@ -552,28 +551,12 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
         // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc.
         // increment all the "future" shards to update the total ops since we some may work and some may not...
         // and when that happens, we break on total ops, so we must maintain them
-        successfulShardExecution(shardIt);
+        successfulShardExecution();
     }
 
-    private void successfulShardExecution(SearchShardIterator shardsIt) {
-        final int remainingOpsOnIterator;
-        if (shardsIt.skip()) {
-            // It's possible that we're skipping a shard that's unavailable
-            // but its range was available in the IndexMetadata, in that
-            // case the shardsIt.remaining() would be 0, expectedTotalOps
-            // accounts for unavailable shards too.
-            remainingOpsOnIterator = Math.max(shardsIt.remaining(), 1);
-        } else {
-            remainingOpsOnIterator = shardsIt.remaining() + 1;
-        }
-        final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
-        if (xTotalOps == expectedTotalOps) {
+    private void successfulShardExecution() {
+        if ((int) OUTSTANDING_SHARDS.getAndAdd(this, -1) == 1) {
             onPhaseDone();
-        } else if (xTotalOps > expectedTotalOps) {
-            throw new AssertionError(
-                "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
-                new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
-            );
         }
     }
 
@@ -640,7 +623,7 @@ private SearchResponse buildSearchResponse(
             scrollId,
             getNumShards(),
             numSuccess,
-            skippedOps.get(),
+            toSkipShardsIts.size(),
             buildTookInMillis(),
             failures,
             clusters,
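
Note on the pattern above: the diff drops the old totalOps / expectedTotalOps bookkeeping and instead counts a single outstandingShards int down through a VarHandle, so the caller whose getAndAdd returns 1 performed the last decrement and triggers onPhaseDone() exactly once. Below is a minimal, self-contained sketch of that countdown idiom; the ShardCountdown class name and the Runnable callback are illustrative only and not part of this change.

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

// Hypothetical illustration of the countdown used in the PR: one plain int field,
// atomically decremented through a VarHandle; whoever observes the transition
// from 1 to 0 runs the completion callback exactly once.
class ShardCountdown {

    private static final VarHandle OUTSTANDING;

    static {
        try {
            OUTSTANDING = MethodHandles.lookup().findVarHandle(ShardCountdown.class, "outstanding", int.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    @SuppressWarnings("unused") // only accessed via OUTSTANDING
    private int outstanding;

    private final Runnable onDone;

    ShardCountdown(int shards, Runnable onDone) {
        // setRelease is sufficient here: the value is published before any shard callback runs
        OUTSTANDING.setRelease(this, shards);
        this.onDone = onDone;
    }

    void countDown() {
        // getAndAdd returns the previous value; seeing 1 means this call made the final decrement
        if ((int) OUTSTANDING.getAndAdd(this, -1) == 1) {
            onDone.run();
        }
    }

    public static void main(String[] args) {
        ShardCountdown countdown = new ShardCountdown(3, () -> System.out.println("all shards done"));
        countdown.countDown();
        countdown.countDown();
        countdown.countDown(); // prints "all shards done" exactly once
    }
}

Presumably the VarHandle over a plain int is preferred to an AtomicInteger because it gives the same atomic getAndAdd semantics without allocating an extra object per search request.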