3838import java .util .List ;
3939import java .util .Map ;
4040import java .util .Objects ;
41+ import java .util .Set ;
42+ import java .util .concurrent .ConcurrentHashMap ;
4143import java .util .concurrent .Executor ;
42- import java .util .concurrent .atomic .AtomicReferenceArray ;
4344import java .util .function .BiFunction ;
4445
4546import static org .elasticsearch .core .Strings .format ;
@@ -277,12 +278,11 @@ private Map<SendingTarget, List<SearchShardIterator>> groupByNode(List<SearchSha
277278 private class Round extends AbstractRunnable {
278279 private final List <SearchShardIterator > shards ;
279280 private final CountDown countDown ;
280- private final AtomicReferenceArray < Exception > failedResponses ;
281+ private final Set < Integer > failedResponses = ConcurrentHashMap . newKeySet () ;
281282
282283 Round (List <SearchShardIterator > shards ) {
283284 this .shards = shards ;
284285 this .countDown = new CountDown (shards .size ());
285- this .failedResponses = new AtomicReferenceArray <>(shardsIts .size ());
286286 }
287287
288288 @ Override
@@ -296,9 +296,7 @@ protected void doRun() {
296296
297297 if (entry .getKey ().nodeId == null ) {
298298 // no target node: just mark the requests as failed
299- for (CanMatchNodeRequest .Shard shard : shardLevelRequests ) {
300- onOperationFailed (shard .getShardRequestIndex (), null );
301- }
299+ onAllFailed (shardLevelRequests );
302300 continue ;
303301 }
304302
@@ -321,37 +319,39 @@ public void onResponse(CanMatchNodeResponse canMatchNodeResponse) {
321319 } else {
322320 Exception failure = response .getException ();
323321 assert failure != null ;
324- onOperationFailed (shardLevelRequests .get (i ).getShardRequestIndex (), failure );
322+ onOperationFailed (shardLevelRequests .get (i ).getShardRequestIndex ());
325323 }
326324 }
327325 }
328326
329327 @ Override
330328 public void onFailure (Exception e ) {
331- for (CanMatchNodeRequest .Shard shard : shardLevelRequests ) {
332- onOperationFailed (shard .getShardRequestIndex (), e );
333- }
329+ onAllFailed (shardLevelRequests );
334330 }
335331 }
336332 );
337333 } catch (Exception e ) {
338- for (CanMatchNodeRequest .Shard shard : shardLevelRequests ) {
339- onOperationFailed (shard .getShardRequestIndex (), e );
340- }
334+ onAllFailed (shardLevelRequests );
341335 }
342336 }
343337 }
344338
339+ private void onAllFailed (List <CanMatchNodeRequest .Shard > shardLevelRequests ) {
340+ for (CanMatchNodeRequest .Shard shard : shardLevelRequests ) {
341+ onOperationFailed (shard .getShardRequestIndex ());
342+ }
343+ }
344+
345345 private void onOperation (int idx , CanMatchShardResponse response ) {
346- failedResponses .set (idx , null );
346+ failedResponses .add (idx );
347347 consumeResult (response );
348348 if (countDown .countDown ()) {
349349 finishRound ();
350350 }
351351 }
352352
353- private void onOperationFailed (int idx , Exception e ) {
354- failedResponses .set (idx , e );
353+ private void onOperationFailed (int idx ) {
354+ failedResponses .add (idx );
355355 // we have to carry over shard failures in order to account for them in the response.
356356 consumeResult (idx , true , null );
357357 if (countDown .countDown ()) {
@@ -363,8 +363,7 @@ private void finishRound() {
363363 List <SearchShardIterator > remainingShards = new ArrayList <>();
364364 for (SearchShardIterator ssi : shards ) {
365365 int shardIndex = shardItIndexMap .get (ssi );
366- Exception failedResponse = failedResponses .get (shardIndex );
367- if (failedResponse != null ) {
366+ if (failedResponses .contains (shardIndex )) {
368367 remainingShards .add (ssi );
369368 }
370369 }
0 commit comments