From e664c040e4bcf023d374bc4f0733e5f0c6a11e92 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Wed, 5 Feb 2025 11:56:29 +0100 Subject: [PATCH] Simplify counting in AbstractSearchAsyncAction (#120593) No need to do this so complicated, just count down one when we're actually done with a specific shard id. --- .../search/AbstractSearchAsyncAction.java | 64 ++++++------------- .../cluster/routing/GroupShardsIterator.java | 20 ------ .../action/search/ExpandSearchPhaseTests.java | 1 - .../routing/GroupShardsIteratorTests.java | 27 -------- 4 files changed, 19 insertions(+), 93 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 4030e8e15ce3a..ec1f2cedfd7d7 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -91,15 +91,13 @@ abstract class AbstractSearchAsyncAction exten private final Object shardFailuresMutex = new Object(); private final AtomicBoolean hasShardResponse = new AtomicBoolean(false); private final AtomicInteger successfulOps = new AtomicInteger(); - private final AtomicInteger skippedOps = new AtomicInteger(); private final SearchTimeProvider timeProvider; private final SearchResponse.Clusters clusters; protected final GroupShardsIterator toSkipShardsIts; protected final GroupShardsIterator shardsIts; private final SearchShardIterator[] shardIterators; - private final int expectedTotalOps; - private final AtomicInteger totalOps = new AtomicInteger(); + private final AtomicInteger outstandingShards; private final int maxConcurrentRequestsPerNode; private final Map pendingExecutionsPerNode = new ConcurrentHashMap<>(); private final boolean throttleConcurrentRequests; @@ -140,18 +138,12 @@ abstract class AbstractSearchAsyncAction exten } this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators); this.shardsIts = new GroupShardsIterator<>(iterators); - + outstandingShards = new AtomicInteger(shardsIts.size()); this.shardIterators = iterators.toArray(new SearchShardIterator[0]); // we later compute the shard index based on the natural order of the shards // that participate in the search request. This means that this number is // consistent between two requests that target the same shards. Arrays.sort(shardIterators); - - // we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up - // it's number of active shards but use 1 as the default if no replica of a shard is active at this point. - // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result - // we process hence we add one for the non active partition here. - this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty(); this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode; // in the case were we have less shards than maxConcurrentRequestsPerNode we don't need to throttle this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size(); @@ -261,9 +253,8 @@ public final void run() { void skipShard(SearchShardIterator iterator) { successfulOps.incrementAndGet(); - skippedOps.incrementAndGet(); assert iterator.skip(); - successfulShardExecution(iterator); + successfulShardExecution(); } private boolean checkMinimumVersion(GroupShardsIterator shardsIts) { @@ -405,7 +396,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})", discrepancy, successfulOps.get(), - skippedOps.get(), + toSkipShardsIts.size(), getNumShards(), currentPhase.getName() ); @@ -474,17 +465,14 @@ private void onShardFailure(final int shardIndex, SearchShardTarget shard, final } onShardGroupFailure(shardIndex, shard, e); } - final int totalOps = this.totalOps.incrementAndGet(); - if (totalOps == expectedTotalOps) { - onPhaseDone(); - } else if (totalOps > expectedTotalOps) { - throw new AssertionError( - "unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]", - new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures()) - ); + if (lastShard == false) { + performPhaseOnShard(shardIndex, shardIt, nextShard); } else { - if (lastShard == false) { - performPhaseOnShard(shardIndex, shardIt, nextShard); + // count down outstanding shards, we're done with this shard as there's no more copies to try + final int outstanding = outstandingShards.decrementAndGet(); + assert outstanding >= 0 : "outstanding: " + outstanding; + if (outstanding == 0) { + onPhaseDone(); } } } @@ -561,10 +549,10 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) { if (logger.isTraceEnabled()) { logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null); } - results.consumeResult(result, () -> onShardResultConsumed(result, shardIt)); + results.consumeResult(result, () -> onShardResultConsumed(result)); } - private void onShardResultConsumed(Result result, SearchShardIterator shardIt) { + private void onShardResultConsumed(Result result) { successfulOps.incrementAndGet(); // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level // so its ok concurrency wise to miss potentially the shard failures being created because of another failure @@ -578,28 +566,14 @@ private void onShardResultConsumed(Result result, SearchShardIterator shardIt) { // cause the successor to read a wrong value from successfulOps if second phase is very fast ie. count etc. // increment all the "future" shards to update the total ops since we some may work and some may not... // and when that happens, we break on total ops, so we must maintain them - successfulShardExecution(shardIt); + successfulShardExecution(); } - private void successfulShardExecution(SearchShardIterator shardsIt) { - final int remainingOpsOnIterator; - if (shardsIt.skip()) { - // It's possible that we're skipping a shard that's unavailable - // but its range was available in the IndexMetadata, in that - // case the shardsIt.remaining() would be 0, expectedTotalOps - // accounts for unavailable shards too. - remainingOpsOnIterator = Math.max(shardsIt.remaining(), 1); - } else { - remainingOpsOnIterator = shardsIt.remaining() + 1; - } - final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator); - if (xTotalOps == expectedTotalOps) { + private void successfulShardExecution() { + final int outstanding = outstandingShards.decrementAndGet(); + assert outstanding >= 0 : "outstanding: " + outstanding; + if (outstanding == 0) { onPhaseDone(); - } else if (xTotalOps > expectedTotalOps) { - throw new AssertionError( - "unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]", - new SearchPhaseExecutionException(getName(), "Shard failures", null, buildShardFailures()) - ); } } @@ -666,7 +640,7 @@ private SearchResponse buildSearchResponse( scrollId, getNumShards(), numSuccess, - skippedOps.get(), + toSkipShardsIts.size(), buildTookInMillis(), failures, clusters, diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/GroupShardsIterator.java b/server/src/main/java/org/elasticsearch/cluster/routing/GroupShardsIterator.java index 32f9530e4b185..590a1bbb16928 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/GroupShardsIterator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/GroupShardsIterator.java @@ -41,26 +41,6 @@ public GroupShardsIterator(List iterators) { this.iterators = iterators; } - /** - * Returns the total number of shards within all groups - * @return total number of shards - */ - public int totalSize() { - return iterators.stream().mapToInt(Countable::size).sum(); - } - - /** - * Returns the total number of shards plus the number of empty groups - * @return number of shards and empty groups - */ - public int totalSizeWith1ForEmpty() { - int size = 0; - for (ShardIt shard : iterators) { - size += Math.max(1, shard.size()); - } - return size; - } - /** * Return the number of groups * @return number of groups diff --git a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java index 5fb70500d515f..5ba5a4f903e83 100644 --- a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java @@ -144,7 +144,6 @@ public void run() { hits.decRef(); } } finally { - mockSearchPhaseContext.execute(() -> {}); var resp = mockSearchPhaseContext.searchResponse.get(); if (resp != null) { resp.decRef(); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/GroupShardsIteratorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/GroupShardsIteratorTests.java index 8e111c3676284..d354658396a0b 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/GroupShardsIteratorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/GroupShardsIteratorTests.java @@ -38,33 +38,6 @@ private static List randomShardRoutings(ShardId shardId, int numRe return shardRoutings; } - public void testSize() { - List list = new ArrayList<>(); - Index index = new Index("foo", "na"); - { - ShardId shardId = new ShardId(index, 0); - list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 2))); - } - list.add(new PlainShardIterator(new ShardId(index, 1), Collections.emptyList())); - { - ShardId shardId = new ShardId(index, 2); - list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 0))); - } - index = new Index("foo_1", "na"); - { - ShardId shardId = new ShardId(index, 0); - list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 0))); - } - { - ShardId shardId = new ShardId(index, 1); - list.add(new PlainShardIterator(shardId, randomShardRoutings(shardId, 0))); - } - GroupShardsIterator iter = new GroupShardsIterator<>(list); - assertEquals(7, iter.totalSizeWith1ForEmpty()); - assertEquals(5, iter.size()); - assertEquals(6, iter.totalSize()); - } - public void testIterate() { List list = new ArrayList<>(); Index index = new Index("foo", "na");