Commit bd5c03b

Simplify counting in AbstractSearchAsyncAction
There's no need for this to be so complicated: just count down by one when we're actually done with a specific shard id.
1 parent: 5efe216
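
The change replaces the old expectedTotalOps / totalOps arithmetic with a plain countdown: the counter starts at the number of shard iterators and is decremented exactly once per shard when that shard is truly finished (it produced a result, failed on its last remaining copy, or was skipped), and whichever caller performs the final decrement triggers onPhaseDone(). Below is a minimal sketch of that idiom as a hypothetical standalone class; OutstandingShardsSketch and onShardFinished are illustrative names, not identifiers from the commit.

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

// Hypothetical illustration of the countdown idiom this commit adopts;
// this class does not exist in Elasticsearch.
class OutstandingShardsSketch {

    private static final VarHandle OUTSTANDING;

    static {
        try {
            OUTSTANDING = MethodHandles.lookup().findVarHandle(OutstandingShardsSketch.class, "outstanding", int.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    @SuppressWarnings("unused") // only read and written through OUTSTANDING
    private int outstanding;

    OutstandingShardsSketch(int shardCount) {
        // A release store is sufficient here: the count is published
        // before any shard callback can observe the instance.
        OUTSTANDING.setRelease(this, shardCount);
    }

    // Called exactly once per shard, whether it succeeded, exhausted all
    // of its copies with failures, or was skipped.
    void onShardFinished(Runnable onPhaseDone) {
        // getAndAdd returns the value *before* the decrement, so the one
        // caller that sees 1 performed the final decrement and is the
        // only one to run the completion hook.
        if ((int) OUTSTANDING.getAndAdd(this, -1) == 1) {
            onPhaseDone.run();
        }
    }
}

Because VarHandle.getAndAdd returns the value before the decrement, exactly one caller can observe 1, so the completion hook can neither run twice nor be missed.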

2 files changed: +28 −46 lines

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 28 additions & 45 deletions
@@ -42,6 +42,8 @@
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.Transport;
 
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -90,15 +92,25 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
     private final Object shardFailuresMutex = new Object();
     private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
     private final AtomicInteger successfulOps = new AtomicInteger();
-    private final AtomicInteger skippedOps = new AtomicInteger();
     private final SearchTimeProvider timeProvider;
     private final SearchResponse.Clusters clusters;
 
     protected final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
     protected final GroupShardsIterator<SearchShardIterator> shardsIts;
     private final SearchShardIterator[] shardIterators;
-    private final int expectedTotalOps;
-    private final AtomicInteger totalOps = new AtomicInteger();
+
+    private static final VarHandle OUTSTANDING_SHARDS;
+
+    static {
+        try {
+            OUTSTANDING_SHARDS = MethodHandles.lookup().findVarHandle(AbstractSearchAsyncAction.class, "outstandingShards", int.class);
+        } catch (Exception e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    @SuppressWarnings("unused") // only accessed via #OUTSTANDING_SHARDS
+    private int outstandingShards;
     private final int maxConcurrentRequestsPerNode;
     private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
     private final boolean throttleConcurrentRequests;
@@ -139,18 +151,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
         }
         this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators);
         this.shardsIts = new GroupShardsIterator<>(iterators);
-
+        OUTSTANDING_SHARDS.setRelease(this, shardsIts.size());
         this.shardIterators = iterators.toArray(new SearchShardIterator[0]);
         // we later compute the shard index based on the natural order of the shards
         // that participate in the search request. This means that this number is
         // consistent between two requests that target the same shards.
         Arrays.sort(shardIterators);
-
-        // we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
-        // it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
-        // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
-        // we process hence we add one for the non active partition here.
-        this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty();
         this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode;
         // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle
         this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size();
@@ -251,9 +257,8 @@ protected final void run() {
 
     void skipShard(SearchShardIterator iterator) {
         successfulOps.incrementAndGet();
-        skippedOps.incrementAndGet();
         assert iterator.skip();
-        successfulShardExecution(iterator);
+        successfulShardExecution();
     }
 
     private static boolean assertExecuteOnStartThread() {
@@ -380,7 +385,7 @@ protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextP
                 "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})",
                 discrepancy,
                 successfulOps.get(),
-                skippedOps.get(),
+                toSkipShardsIts.size(),
                 getNumShards(),
                 currentPhase
             );
@@ -449,17 +454,11 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final
             }
             onShardGroupFailure(shardIndex, shard, e);
         }
-        final int totalOps = this.totalOps.incrementAndGet();
-        if (totalOps == expectedTotalOps) {
-            onPhaseDone();
-        } else if (totalOps > expectedTotalOps) {
-            throw new AssertionError(
-                "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
-                new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
-            );
+        if (lastShard == false) {
+            performPhaseOnShard(shardIndex, shardIt, nextShard);
         } else {
-            if (lastShard == false) {
-                performPhaseOnShard(shardIndex, shardIt, nextShard);
+            if ((int) OUTSTANDING_SHARDS.getAndAdd(this, -1) == 1) {
+                onPhaseDone();
             }
         }
     }
@@ -535,10 +534,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
         if (logger.isTraceEnabled()) {
             logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
         }
-        results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
+        results.consumeResult(result, () -> onShardResultConsumed(result));
     }
 
-    private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
+    private void onShardResultConsumed(Result result) {
         successfulOps.incrementAndGet();
         // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
         // so its ok concurrency wise to miss potentially the shard failures being created because of another failure
@@ -552,28 +551,12 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
         // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc.
         // increment all the "future" shards to update the total ops since we some may work and some may not...
         // and when that happens, we break on total ops, so we must maintain them
-        successfulShardExecution(shardIt);
+        successfulShardExecution();
     }
 
-    private void successfulShardExecution(SearchShardIterator shardsIt) {
-        final int remainingOpsOnIterator;
-        if (shardsIt.skip()) {
-            // It's possible that we're skipping a shard that's unavailable
-            // but its range was available in the IndexMetadata, in that
-            // case the shardsIt.remaining() would be 0, expectedTotalOps
-            // accounts for unavailable shards too.
-            remainingOpsOnIterator = Math.max(shardsIt.remaining(), 1);
-        } else {
-            remainingOpsOnIterator = shardsIt.remaining() + 1;
-        }
-        final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
-        if (xTotalOps == expectedTotalOps) {
+    private void successfulShardExecution() {
+        if ((int) OUTSTANDING_SHARDS.getAndAdd(this, -1) == 1) {
             onPhaseDone();
-        } else if (xTotalOps > expectedTotalOps) {
-            throw new AssertionError(
-                "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
-                new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
-            );
         }
     }
 
@@ -640,7 +623,7 @@ private SearchResponse buildSearchResponse(
             scrollId,
             getNumShards(),
             numSuccess,
-            skippedOps.get(),
+            toSkipShardsIts.size(),
             buildTookInMillis(),
             failures,
             clusters,
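
A note on the design choice, as a reading of the diff (the commit message doesn't spell this out): keeping the counter in a plain int field reached through a static VarHandle, rather than in another AtomicInteger, avoids one extra object per search action while still providing the atomic getAndAdd the countdown needs. And because every shard id now accounts for exactly one unit, the per-iterator remaining() / totalSizeWith1ForEmpty() bookkeeping and the overshoot AssertionErrors become unnecessary, which is where most of the 45 deleted lines come from.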

server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java

Lines changed: 0 additions & 1 deletion
@@ -144,7 +144,6 @@ protected void run() {
                 hits.decRef();
             }
         } finally {
-            mockSearchPhaseContext.execute(() -> {});
             var resp = mockSearchPhaseContext.searchResponse.get();
             if (resp != null) {
                 resp.decRef();
