diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index e42f8127c5e97..24aabc59d84a6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -38,8 +38,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiFunction; import static org.elasticsearch.core.Strings.format; @@ -226,25 +227,17 @@ private void runCoordinatorRewritePhase() { } private void consumeResult(boolean canMatch, ShardSearchRequest request) { - CanMatchShardResponse result = new CanMatchShardResponse(canMatch, null); - result.setShardIndex(request.shardRequestIndex()); - consumeResult(result); + consumeResult(request.shardRequestIndex(), canMatch, null); } - private void consumeResult(CanMatchShardResponse result) { - final boolean canMatch = result.canMatch(); - final MinAndMax minAndMax = result.estimatedMinAndMax(); - if (canMatch || minAndMax != null) { - consumeResult(result.getShardIndex(), canMatch, minAndMax); - } - } - - private synchronized void consumeResult(int shardIndex, boolean canMatch, MinAndMax minAndMax) { + private void consumeResult(int shardIndex, boolean canMatch, MinAndMax minAndMax) { if (canMatch) { - possibleMatches.set(shardIndex); - numPossibleMatches++; + synchronized (this) { + possibleMatches.set(shardIndex); + numPossibleMatches++; + minAndMaxes[shardIndex] = minAndMax; + } } - minAndMaxes[shardIndex] = minAndMax; } private void checkNoMissingShards(List shards) { @@ -277,12 +270,11 @@ private Map> groupByNode(List shards; private final CountDown countDown; - private final AtomicReferenceArray failedResponses; + private final Set failedResponses = ConcurrentHashMap.newKeySet(); Round(List shards) { this.shards = shards; this.countDown = new CountDown(shards.size()); - this.failedResponses = new AtomicReferenceArray<>(shardsIts.size()); } @Override @@ -295,9 +287,10 @@ protected void doRun() { List shardLevelRequests = canMatchNodeRequest.getShardLevelRequests(); if (entry.getKey().nodeId == null) { - // no target node: just mark the requests as failed + // no target node: just mark as matching and have the next phase fail the shard operation if needed for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), null); + var resp = new CanMatchShardResponse(true, null); + onOperation(shard.getShardRequestIndex(), resp); } continue; } @@ -321,37 +314,39 @@ public void onResponse(CanMatchNodeResponse canMatchNodeResponse) { } else { Exception failure = response.getException(); assert failure != null; - onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure); + onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex()); } } } @Override public void onFailure(Exception e) { - for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), e); - } + onAllFailed(shardLevelRequests); } } ); } catch (Exception e) { - for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), e); - } + onAllFailed(shardLevelRequests); } } } + private void onAllFailed(List shardLevelRequests) { + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + onOperationFailed(shard.getShardRequestIndex()); + } + } + private void onOperation(int idx, CanMatchShardResponse response) { - failedResponses.set(idx, null); - consumeResult(response); + failedResponses.remove(idx); + consumeResult(idx, response.canMatch(), response.estimatedMinAndMax()); if (countDown.countDown()) { finishRound(); } } - private void onOperationFailed(int idx, Exception e) { - failedResponses.set(idx, e); + private void onOperationFailed(int idx) { + failedResponses.add(idx); // we have to carry over shard failures in order to account for them in the response. consumeResult(idx, true, null); if (countDown.countDown()) { @@ -363,8 +358,7 @@ private void finishRound() { List remainingShards = new ArrayList<>(); for (SearchShardIterator ssi : shards) { int shardIndex = shardItIndexMap.get(ssi); - Exception failedResponse = failedResponses.get(shardIndex); - if (failedResponse != null) { + if (failedResponses.contains(shardIndex)) { remainingShards.add(ssi); } }