Skip to content

Commit ace066d

Browse files
authored
Fix race condition when resolving new location for multiple shards at once (#128062) (#128175)
(cherry picked from commit f4b6086)
1 parent 708a1ae commit ace066d

File tree

2 files changed

+48
-22
lines changed

2 files changed

+48
-22
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -144,8 +144,7 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
144144
var pendingRetries = new HashSet<ShardId>();
145145
for (ShardId shardId : pendingShardIds) {
146146
if (targetShards.getShard(shardId).remainingNodes.isEmpty()) {
147-
var failure = shardFailures.get(shardId);
148-
if (failure != null && failure.fatal == false && failure.failure instanceof NoShardAvailableActionException) {
147+
if (isRetryableFailure(shardFailures.get(shardId))) {
149148
pendingRetries.add(shardId);
150149
}
151150
}
@@ -156,7 +155,8 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
156155
}
157156
}
158157
for (ShardId shardId : pendingShardIds) {
159-
if (targetShards.getShard(shardId).remainingNodes.isEmpty()) {
158+
if (targetShards.getShard(shardId).remainingNodes.isEmpty()
159+
&& (isRetryableFailure(shardFailures.get(shardId)) == false || pendingRetries.contains(shardId))) {
160160
shardFailures.compute(
161161
shardId,
162162
(k, v) -> new ShardFailure(
@@ -327,6 +327,10 @@ record NodeRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasF
327327

328328
private record ShardFailure(boolean fatal, Exception failure) {}
329329

330+
private static boolean isRetryableFailure(ShardFailure failure) {
331+
return failure != null && failure.fatal == false && failure.failure instanceof NoShardAvailableActionException;
332+
}
333+
330334
/**
331335
* Selects the next nodes to send requests to. Limits to at most one outstanding request per node.
332336
* If there is already a request in-flight to a node, another request will not be sent to the same node

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java

Lines changed: 41 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,8 @@
5151
import java.util.concurrent.atomic.AtomicBoolean;
5252
import java.util.concurrent.atomic.AtomicInteger;
5353
import java.util.function.Function;
54-
import java.util.stream.Collectors;
5554

55+
import static java.util.stream.Collectors.toMap;
5656
import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest;
5757
import static org.hamcrest.Matchers.contains;
5858
import static org.hamcrest.Matchers.containsString;
@@ -402,6 +402,32 @@ public void testRetryMovedShard() {
402402
assertThat(attempt.get(), equalTo(3));
403403
}
404404

405+
public void testRetryMultipleMovedShards() {
406+
var attempt = new AtomicInteger(0);
407+
var response = safeGet(
408+
sendRequests(
409+
randomBoolean(),
410+
-1,
411+
List.of(targetShard(shard1, node1), targetShard(shard2, node2), targetShard(shard3, node3)),
412+
shardIds -> shardIds.stream().collect(toMap(Function.identity(), shardId -> List.of(randomFrom(node1, node2, node3)))),
413+
(node, shardIds, aliasFilters, listener) -> runWithDelay(
414+
() -> listener.onResponse(
415+
attempt.incrementAndGet() <= 6
416+
? new DataNodeComputeResponse(
417+
List.of(),
418+
shardIds.stream().collect(toMap(Function.identity(), ShardNotFoundException::new))
419+
)
420+
: new DataNodeComputeResponse(List.of(), Map.of())
421+
)
422+
)
423+
)
424+
);
425+
assertThat(response.totalShards, equalTo(3));
426+
assertThat(response.successfulShards, equalTo(3));
427+
assertThat(response.skippedShards, equalTo(0));
428+
assertThat(response.failedShards, equalTo(0));
429+
}
430+
405431
public void testDoesNotRetryMovedShardIndefinitely() {
406432
var attempt = new AtomicInteger(0);
407433
var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1)), shardIds -> {
@@ -463,28 +489,28 @@ public void testRetryUnassignedShardWithoutPartialResults() {
463489

464490
);
465491
expectThrows(NoShardAvailableActionException.class, containsString("no such shard"), future::actionGet);
492+
assertThat(attempt.get(), equalTo(1));
466493
}
467494

468495
public void testRetryUnassignedShardWithPartialResults() {
469-
var response = safeGet(
470-
sendRequests(
471-
true,
472-
-1,
473-
List.of(targetShard(shard1, node1), targetShard(shard2, node2)),
474-
shardIds -> Map.of(shard1, List.of()),
475-
(node, shardIds, aliasFilters, listener) -> runWithDelay(
476-
() -> listener.onResponse(
477-
Objects.equals(shardIds, List.of(shard2))
478-
? new DataNodeComputeResponse(List.of(), Map.of())
479-
: new DataNodeComputeResponse(List.of(), Map.of(shard1, new ShardNotFoundException(shard1)))
480-
)
496+
var attempt = new AtomicInteger(0);
497+
var response = safeGet(sendRequests(true, -1, List.of(targetShard(shard1, node1), targetShard(shard2, node2)), shardIds -> {
498+
attempt.incrementAndGet();
499+
return Map.of(shard1, List.of());
500+
},
501+
(node, shardIds, aliasFilters, listener) -> runWithDelay(
502+
() -> listener.onResponse(
503+
Objects.equals(shardIds, List.of(shard2))
504+
? new DataNodeComputeResponse(List.of(), Map.of())
505+
: new DataNodeComputeResponse(List.of(), Map.of(shard1, new ShardNotFoundException(shard1)))
481506
)
482507
)
483-
);
508+
));
484509
assertThat(response.totalShards, equalTo(2));
485510
assertThat(response.successfulShards, equalTo(1));
486511
assertThat(response.skippedShards, equalTo(0));
487512
assertThat(response.failedShards, equalTo(1));
513+
assertThat(attempt.get(), equalTo(1));
488514
}
489515

490516
static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
@@ -553,11 +579,7 @@ PlainActionFuture<ComputeResponse> sendRequests(
553579
void searchShards(Set<String> concreteIndices, ActionListener<TargetShards> listener) {
554580
runWithDelay(
555581
() -> listener.onResponse(
556-
new TargetShards(
557-
shards.stream().collect(Collectors.toMap(TargetShard::shardId, Function.identity())),
558-
shards.size(),
559-
0
560-
)
582+
new TargetShards(shards.stream().collect(toMap(TargetShard::shardId, Function.identity())), shards.size(), 0)
561583
)
562584
);
563585
}

0 commit comments

Comments
 (0)