Skip to content

Commit ad780e2

Browse files
committed
Fix DataNodeRequestSender (elastic#121999)
There are two issues in the current implementation:

1. We should use the list of shardIds from the request, rather than all targets, when removing failures for shards that have been successfully executed.
2. We should remove shardIds from the pending list once a failure is reported and abort execution at that point, as the results will be discarded.

Closes elastic#121966
1 parent d3ff1cf commit ad780e2

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,9 @@
3434
import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction;
3535

3636
import java.util.ArrayList;
37+
import java.util.Collections;
3738
import java.util.HashMap;
39+
import java.util.IdentityHashMap;
3840
import java.util.Iterator;
3941
import java.util.List;
4042
import java.util.Map;
@@ -58,6 +60,7 @@ abstract class DataNodeRequestSender {
5860
private final Map<DiscoveryNode, Semaphore> nodePermits = new HashMap<>();
5961
private final Map<ShardId, ShardFailure> shardFailures = ConcurrentCollections.newConcurrentMap();
6062
private final AtomicBoolean changed = new AtomicBoolean();
63+
private boolean reportedFailure = false; // guarded by sendingLock
6164

6265
DataNodeRequestSender(TransportService transportService, Executor esqlExecutor, CancellableTask rootTask) {
6366
this.transportService = transportService;
@@ -117,11 +120,14 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
117120
);
118121
}
119122
}
120-
if (shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal)) {
121-
for (var e : shardFailures.values()) {
122-
computeListener.acquireAvoid().onFailure(e.failure);
123-
}
123+
if (reportedFailure || shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal)) {
124+
reportedFailure = true;
125+
reportFailures(computeListener);
124126
} else {
127+
pendingShardIds.removeIf(shr -> {
128+
var failure = shardFailures.get(shr);
129+
return failure != null && failure.fatal;
130+
});
125131
var nodeRequests = selectNodeRequests(targetShards);
126132
for (NodeRequest request : nodeRequests) {
127133
sendOneNodeRequest(targetShards, computeListener, request);
@@ -136,6 +142,20 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
136142
}
137143
}
138144

145+
private void reportFailures(ComputeListener computeListener) {
146+
assert sendingLock.isHeldByCurrentThread();
147+
assert reportedFailure;
148+
Iterator<ShardFailure> it = shardFailures.values().iterator();
149+
Set<Exception> seen = Collections.newSetFromMap(new IdentityHashMap<>());
150+
while (it.hasNext()) {
151+
ShardFailure failure = it.next();
152+
if (seen.add(failure.failure)) {
153+
computeListener.acquireAvoid().onFailure(failure.failure);
154+
}
155+
it.remove();
156+
}
157+
}
158+
139159
private void sendOneNodeRequest(TargetShards targetShards, ComputeListener computeListener, NodeRequest request) {
140160
final ActionListener<List<DriverProfile>> listener = computeListener.acquireCompute();
141161
sendRequest(request.node, request.shardIds, request.aliasFilters, new NodeListener() {
@@ -148,7 +168,7 @@ void onAfter(List<DriverProfile> profiles) {
148168
@Override
149169
public void onResponse(DataNodeComputeResponse response) {
150170
// remove failures of successful shards
151-
for (ShardId shardId : targetShards.shardIds()) {
171+
for (ShardId shardId : request.shardIds()) {
152172
if (response.shardLevelFailures().containsKey(shardId) == false) {
153173
shardFailures.remove(shardId);
154174
}
@@ -250,6 +270,7 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
250270
final Iterator<ShardId> shardsIt = pendingShardIds.iterator();
251271
while (shardsIt.hasNext()) {
252272
ShardId shardId = shardsIt.next();
273+
assert shardFailures.get(shardId) == null || shardFailures.get(shardId).fatal == false;
253274
TargetShard shard = targetShards.getShard(shardId);
254275
Iterator<DiscoveryNode> nodesIt = shard.remainingNodes.iterator();
255276
DiscoveryNode selectedNode = null;

0 commit comments

Comments
 (0)