diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 7bde65abbc7db..3ba78c5cc778c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -86,7 +86,6 @@ abstract class AbstractSearchAsyncAction exten private final SetOnce> shardFailures = new SetOnce<>(); private final Object shardFailuresMutex = new Object(); private final AtomicBoolean hasShardResponse = new AtomicBoolean(false); - private final AtomicInteger successfulOps; private final SearchTimeProvider timeProvider; private final SearchResponse.Clusters clusters; @@ -135,7 +134,6 @@ abstract class AbstractSearchAsyncAction exten this.skippedCount = skipped; this.shardsIts = iterators; outstandingShards = new AtomicInteger(iterators.size()); - successfulOps = new AtomicInteger(skipped); this.shardIterators = iterators.toArray(new SearchShardIterator[0]); // we later compute the shard index based on the natural order of the shards // that participate in the search request. 
This means that this number is @@ -328,32 +326,16 @@ protected void executeNextPhase(String currentPhase, Supplier nextP } else { Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (allowPartialResults == false && successfulOps.get() != getNumShards()) { + if (allowPartialResults == false && shardSearchFailures.length > 0) { // check if there are actual failures in the atomic array since // successful retries can reset the failures to null - if (shardSearchFailures.length > 0) { - if (logger.isDebugEnabled()) { - int numShardFailures = shardSearchFailures.length; - shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); - Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; - logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause); - } - onPhaseFailure(currentPhase, "Partial shards failure", null); - } else { - int discrepancy = getNumShards() - successfulOps.get(); - assert discrepancy > 0 : "discrepancy: " + discrepancy; - if (logger.isDebugEnabled()) { - logger.debug( - "Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})", - discrepancy, - successfulOps.get(), - skippedCount, - getNumShards(), - currentPhase - ); - } - onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null); + if (logger.isDebugEnabled()) { + int numShardFailures = shardSearchFailures.length; + shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); + Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; + logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause); } + onPhaseFailure(currentPhase, "Partial shards failure", null); return; } var nextPhase = nextPhaseSupplier.get(); @@ -466,19 
+448,10 @@ void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Excepti } } ShardSearchFailure failure = shardFailures.get(shardIndex); - if (failure == null) { + // the failure is already present, try and not override it with an exception that is less meaningful + // for example, getting illegal shard state + if (failure == null || (TransportActions.isReadOverrideException(e) && e instanceof SearchContextMissingException == false)) { shardFailures.set(shardIndex, new ShardSearchFailure(e, shardTarget)); - } else { - // the failure is already present, try and not override it with an exception that is less meaningless - // for example, getting illegal shard state - if (TransportActions.isReadOverrideException(e) && (e instanceof SearchContextMissingException == false)) { - shardFailures.set(shardIndex, new ShardSearchFailure(e, shardTarget)); - } - } - - if (results.hasResult(shardIndex)) { - assert failure == null : "shard failed before but shouldn't: " + failure; - successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter } } } @@ -502,7 +475,6 @@ protected void onShardResult(Result result) { } private void onShardResultConsumed(Result result) { - successfulOps.incrementAndGet(); // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level // so its ok concurrency wise to miss potentially the shard failures being created because of another failure // in the #addShardFailure, because by definition, it will happen on *another* shardIndex @@ -510,15 +482,6 @@ private void onShardResultConsumed(Result result) { if (shardFailures != null) { shardFailures.set(result.getShardIndex(), null); } - // we need to increment successful ops first before we compare the exit condition otherwise if we - // are fast we could concurrently update totalOps but then preempt one of the threads which can - // cause the successor to read a wrong value 
from successfulOps if second phase is very fast ie. count etc. - // increment all the "future" shards to update the total ops since we some may work and some may not... - // and when that happens, we break on total ops, so we must maintain them - successfulShardExecution(); - } - - private void successfulShardExecution() { final int outstanding = outstandingShards.decrementAndGet(); assert outstanding >= 0 : "outstanding: " + outstanding; if (outstanding == 0) { @@ -580,15 +543,12 @@ private SearchResponse buildSearchResponse( String scrollId, BytesReference searchContextId ) { - int numSuccess = successfulOps.get(); - int numFailures = failures.length; - assert numSuccess + numFailures == getNumShards() - : "numSuccess(" + numSuccess + ") + numFailures(" + numFailures + ") != totalShards(" + getNumShards() + ")"; + final int numShards = getNumShards(); return new SearchResponse( internalSearchResponse, scrollId, - getNumShards(), - numSuccess, + numShards, + numShards - failures.length, skippedCount, buildTookInMillis(), failures, diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index abe7e893977f4..086bbd9053d4b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -215,33 +215,6 @@ public void testOnPhaseFailure() { assertEquals(requestIds, releasedContexts); } - public void testShardNotAvailableWithDisallowPartialFailures() { - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); - AtomicReference exception = new AtomicReference<>(); - ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set); - int numShards = randomIntBetween(2, 10); - ArraySearchPhaseResults phaseResults = new 
ArraySearchPhaseResults<>(numShards); - AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); - // skip one to avoid the "all shards failed" failure. - action.onShardResult(new SearchPhaseResult() { - @Override - public int getShardIndex() { - return 0; - } - - @Override - public SearchShardTarget getSearchShardTarget() { - return new SearchShardTarget(null, null, null); - } - }); - assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class)); - SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException) exception.get(); - assertEquals("Partial shards failure (" + (numShards - 1) + " shards unavailable)", searchPhaseExecutionException.getMessage()); - assertEquals("test", searchPhaseExecutionException.getPhaseName()); - assertEquals(0, searchPhaseExecutionException.shardFailures().length); - assertEquals(0, searchPhaseExecutionException.getSuppressed().length); - } - private static ArraySearchPhaseResults phaseResults( Set contextIds, List> nodeLookups,