diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java index f76d9643e4a6d..4409201606d0a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java @@ -192,8 +192,7 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu var pendingRetries = new HashSet(); for (ShardId shardId : pendingShardIds) { if (targetShards.getShard(shardId).remainingNodes.isEmpty()) { - var failure = shardFailures.get(shardId); - if (failure != null && failure.fatal == false && failure.failure instanceof NoShardAvailableActionException) { + if (isRetryableFailure(shardFailures.get(shardId))) { pendingRetries.add(shardId); } } @@ -204,7 +203,8 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu } } for (ShardId shardId : pendingShardIds) { - if (targetShards.getShard(shardId).remainingNodes.isEmpty()) { + if (targetShards.getShard(shardId).remainingNodes.isEmpty() + && (isRetryableFailure(shardFailures.get(shardId)) == false || pendingRetries.contains(shardId))) { shardFailures.compute( shardId, (k, v) -> new ShardFailure( @@ -378,6 +378,10 @@ record NodeRequest(DiscoveryNode node, List shardIds, Map shardIds.stream().collect(toMap(Function.identity(), shardId -> List.of(randomFrom(node1, node2, node3)))), + (node, shardIds, aliasFilters, listener) -> runWithDelay( + () -> listener.onResponse( + attempt.incrementAndGet() <= 6 + ? new DataNodeComputeResponse( + DriverCompletionInfo.EMPTY, + shardIds.stream().collect(toMap(Function.identity(), ShardNotFoundException::new)) + ) + : new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()) + ) + ) + ) + ); + assertThat(response.totalShards, equalTo(3)); + assertThat(response.successfulShards, equalTo(3)); + assertThat(response.skippedShards, equalTo(0)); + assertThat(response.failedShards, equalTo(0)); + } + public void testDoesNotRetryMovedShardIndefinitely() { var attempt = new AtomicInteger(0); var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1)), shardIds -> { @@ -517,28 +543,28 @@ public void testRetryUnassignedShardWithoutPartialResults() { ); expectThrows(NoShardAvailableActionException.class, containsString("no such shard"), future::actionGet); + assertThat(attempt.get(), equalTo(1)); } public void testRetryUnassignedShardWithPartialResults() { - var response = safeGet( - sendRequests( - true, - -1, - List.of(targetShard(shard1, node1), targetShard(shard2, node2)), - shardIds -> Map.of(shard1, List.of()), - (node, shardIds, aliasFilters, listener) -> runWithDelay( - () -> listener.onResponse( - Objects.equals(shardIds, List.of(shard2)) - ? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()) - : new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) - ) + var attempt = new AtomicInteger(0); + var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1), targetShard(shard2, node2)), shardIds -> { + attempt.incrementAndGet(); + return Map.of(shard1, List.of()); + }, + (node, shardIds, aliasFilters, listener) -> runWithDelay( + () -> listener.onResponse( + Objects.equals(shardIds, List.of(shard2)) + ? new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()) + : new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of(shard1, new ShardNotFoundException(shard1))) ) ) - ); + )); assertThat(response.totalShards, equalTo(2)); assertThat(response.successfulShards, equalTo(1)); assertThat(response.skippedShards, equalTo(0)); assertThat(response.failedShards, equalTo(1)); + assertThat(attempt.get(), equalTo(1)); } static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) { @@ -621,11 +647,7 @@ PlainActionFuture sendRequests( void searchShards(Set concreteIndices, ActionListener listener) { runWithDelay( () -> listener.onResponse( - new TargetShards( - shards.stream().collect(Collectors.toMap(TargetShard::shardId, Function.identity())), - shards.size(), - 0 - ) + new TargetShards(shards.stream().collect(toMap(TargetShard::shardId, Function.identity())), shards.size(), 0) ) ); }