Skip to content

Commit 35c3fa9

Browse files
Simplify counting in AbstractSearchAsyncAction (#120593) (#122193)
There is no need for this counting to be so complicated: just count down by one when we are actually done with a specific shard ID.
1 parent af9f562 commit 35c3fa9

File tree

4 files changed

+19
-93
lines changed

4 files changed

+19
-93
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,13 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
9191
private final Object shardFailuresMutex = new Object();
9292
private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
9393
private final AtomicInteger successfulOps = new AtomicInteger();
94-
private final AtomicInteger skippedOps = new AtomicInteger();
9594
private final SearchTimeProvider timeProvider;
9695
private final SearchResponse.Clusters clusters;
9796

9897
protected final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
9998
protected final GroupShardsIterator<SearchShardIterator> shardsIts;
10099
private final SearchShardIterator[] shardIterators;
101-
private final int expectedTotalOps;
102-
private final AtomicInteger totalOps = new AtomicInteger();
100+
private final AtomicInteger outstandingShards;
103101
private final int maxConcurrentRequestsPerNode;
104102
private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
105103
private final boolean throttleConcurrentRequests;
@@ -140,18 +138,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
140138
}
141139
this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators);
142140
this.shardsIts = new GroupShardsIterator<>(iterators);
143-
141+
outstandingShards = new AtomicInteger(shardsIts.size());
144142
this.shardIterators = iterators.toArray(new SearchShardIterator[0]);
145143
// we later compute the shard index based on the natural order of the shards
146144
// that participate in the search request. This means that this number is
147145
// consistent between two requests that target the same shards.
148146
Arrays.sort(shardIterators);
149-
150-
// we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
151-
// it's number of active shards but use 1 as the default if no replica of a shard is active at this point.
152-
// on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
153-
// we process hence we add one for the non active partition here.
154-
this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty();
155147
this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode;
156148
// in the case where we have fewer shards than maxConcurrentRequestsPerNode we don't need to throttle
157149
this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size();
@@ -261,9 +253,8 @@ public final void run() {
261253

262254
void skipShard(SearchShardIterator iterator) {
263255
successfulOps.incrementAndGet();
264-
skippedOps.incrementAndGet();
265256
assert iterator.skip();
266-
successfulShardExecution(iterator);
257+
successfulShardExecution();
267258
}
268259

269260
private boolean checkMinimumVersion(GroupShardsIterator<SearchShardIterator> shardsIts) {
@@ -405,7 +396,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
405396
"Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})",
406397
discrepancy,
407398
successfulOps.get(),
408-
skippedOps.get(),
399+
toSkipShardsIts.size(),
409400
getNumShards(),
410401
currentPhase.getName()
411402
);
@@ -474,17 +465,14 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final
474465
}
475466
onShardGroupFailure(shardIndex, shard, e);
476467
}
477-
final int totalOps = this.totalOps.incrementAndGet();
478-
if (totalOps == expectedTotalOps) {
479-
onPhaseDone();
480-
} else if (totalOps > expectedTotalOps) {
481-
throw new AssertionError(
482-
"unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
483-
new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
484-
);
468+
if (lastShard == false) {
469+
performPhaseOnShard(shardIndex, shardIt, nextShard);
485470
} else {
486-
if (lastShard == false) {
487-
performPhaseOnShard(shardIndex, shardIt, nextShard);
471+
// count down outstanding shards, we're done with this shard as there's no more copies to try
472+
final int outstanding = outstandingShards.decrementAndGet();
473+
assert outstanding >= 0 : "outstanding: " + outstanding;
474+
if (outstanding == 0) {
475+
onPhaseDone();
488476
}
489477
}
490478
}
@@ -561,10 +549,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
561549
if (logger.isTraceEnabled()) {
562550
logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
563551
}
564-
results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
552+
results.consumeResult(result, () -> onShardResultConsumed(result));
565553
}
566554

567-
private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
555+
private void onShardResultConsumed(Result result) {
568556
successfulOps.incrementAndGet();
569557
// clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
570558
// so it's OK concurrency-wise to potentially miss the shard failures being created because of another failure
@@ -578,28 +566,14 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
578566
// cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc.
579567
// increment all the "future" shards to update the total ops, since some may work and some may not...
580568
// and when that happens, we break on total ops, so we must maintain them
581-
successfulShardExecution(shardIt);
569+
successfulShardExecution();
582570
}
583571

584-
private void successfulShardExecution(SearchShardIterator shardsIt) {
585-
final int remainingOpsOnIterator;
586-
if (shardsIt.skip()) {
587-
// It's possible that we're skipping a shard that's unavailable
588-
// but its range was available in the IndexMetadata, in that
589-
// case the shardsIt.remaining() would be 0, expectedTotalOps
590-
// accounts for unavailable shards too.
591-
remainingOpsOnIterator = Math.max(shardsIt.remaining(), 1);
592-
} else {
593-
remainingOpsOnIterator = shardsIt.remaining() + 1;
594-
}
595-
final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
596-
if (xTotalOps == expectedTotalOps) {
572+
private void successfulShardExecution() {
573+
final int outstanding = outstandingShards.decrementAndGet();
574+
assert outstanding >= 0 : "outstanding: " + outstanding;
575+
if (outstanding == 0) {
597576
onPhaseDone();
598-
} else if (xTotalOps > expectedTotalOps) {
599-
throw new AssertionError(
600-
"unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",
601-
new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures())
602-
);
603577
}
604578
}
605579

@@ -666,7 +640,7 @@ private SearchResponse buildSearchResponse(
666640
scrollId,
667641
getNumShards(),
668642
numSuccess,
669-
skippedOps.get(),
643+
toSkipShardsIts.size(),
670644
buildTookInMillis(),
671645
failures,
672646
clusters,

server/src/main/java/org/elasticsearch/cluster/routing/GroupShardsIterator.java

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,6 @@ public GroupShardsIterator(List<ShardIt> iterators) {
4141
this.iterators = iterators;
4242
}
4343

44-
/**
45-
* Returns the total number of shards within all groups
46-
* @return total number of shards
47-
*/
48-
public int totalSize() {
49-
return iterators.stream().mapToInt(Countable::size).sum();
50-
}
51-
52-
/**
53-
* Returns the total number of shards plus the number of empty groups
54-
* @return number of shards and empty groups
55-
*/
56-
public int totalSizeWith1ForEmpty() {
57-
int size = 0;
58-
for (ShardIt shard : iterators) {
59-
size += Math.max(1, shard.size());
60-
}
61-
return size;
62-
}
63-
6444
/**
6545
* Return the number of groups
6646
* @return number of groups

server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ public void run() {
144144
hits.decRef();
145145
}
146146
} finally {
147-
mockSearchPhaseContext.execute(() -> {});
148147
var resp = mockSearchPhaseContext.searchResponse.get();
149148
if (resp != null) {
150149
resp.decRef();

server/src/test/java/org/elasticsearch/cluster/routing/GroupShardsIteratorTests.java

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,33 +38,6 @@ private static List<ShardRouting> randomShardRoutings(ShardId shardId, int numRe
3838
return shardRoutings;
3939
}
4040

41-
public void testSize() {
42-
List<ShardIterator> list = new ArrayList<>();
43-
Index index = new Index("foo", "na");
44-
{
45-
ShardId shardId = new ShardId(index, 0);
46-
list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 2)));
47-
}
48-
list.add(new PlainShardIterator(new ShardId(index, 1), Collections.emptyList()));
49-
{
50-
ShardId shardId = new ShardId(index, 2);
51-
list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 0)));
52-
}
53-
index = new Index("foo_1", "na");
54-
{
55-
ShardId shardId = new ShardId(index, 0);
56-
list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 0)));
57-
}
58-
{
59-
ShardId shardId = new ShardId(index, 1);
60-
list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 0)));
61-
}
62-
GroupShardsIterator<ShardIterator> iter = new GroupShardsIterator<>(list);
63-
assertEquals(7, iter.totalSizeWith1ForEmpty());
64-
assertEquals(5, iter.size());
65-
assertEquals(6, iter.totalSize());
66-
}
67-
6841
public void testIterate() {
6942
List<ShardIterator> list = new ArrayList<>();
7043
Index index = new Index("foo", "na");

0 commit comments

Comments
 (0)