3434import org .elasticsearch .xpack .esql .action .EsqlSearchShardsAction ;
3535
3636import java .util .ArrayList ;
37+ import java .util .Collections ;
3738import java .util .HashMap ;
39+ import java .util .IdentityHashMap ;
3840import java .util .Iterator ;
3941import java .util .List ;
4042import java .util .Map ;
@@ -58,6 +60,7 @@ abstract class DataNodeRequestSender {
5860 private final Map <DiscoveryNode , Semaphore > nodePermits = new HashMap <>();
5961 private final Map <ShardId , ShardFailure > shardFailures = ConcurrentCollections .newConcurrentMap ();
6062 private final AtomicBoolean changed = new AtomicBoolean ();
63+ private boolean reportedFailure = false ; // guarded by sendingLock
6164
6265 DataNodeRequestSender (TransportService transportService , Executor esqlExecutor , CancellableTask rootTask ) {
6366 this .transportService = transportService ;
@@ -117,11 +120,14 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
117120 );
118121 }
119122 }
120- if (shardFailures .values ().stream ().anyMatch (shardFailure -> shardFailure .fatal )) {
121- for (var e : shardFailures .values ()) {
122- computeListener .acquireAvoid ().onFailure (e .failure );
123- }
123+ if (reportedFailure || shardFailures .values ().stream ().anyMatch (shardFailure -> shardFailure .fatal )) {
124+ reportedFailure = true ;
125+ reportFailures (computeListener );
124126 } else {
127+ pendingShardIds .removeIf (shr -> {
128+ var failure = shardFailures .get (shr );
129+ return failure != null && failure .fatal ;
130+ });
125131 var nodeRequests = selectNodeRequests (targetShards );
126132 for (NodeRequest request : nodeRequests ) {
127133 sendOneNodeRequest (targetShards , computeListener , request );
@@ -136,6 +142,20 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
136142 }
137143 }
138144
145+ private void reportFailures (ComputeListener computeListener ) {
146+ assert sendingLock .isHeldByCurrentThread ();
147+ assert reportedFailure ;
148+ Iterator <ShardFailure > it = shardFailures .values ().iterator ();
149+ Set <Exception > seen = Collections .newSetFromMap (new IdentityHashMap <>());
150+ while (it .hasNext ()) {
151+ ShardFailure failure = it .next ();
152+ if (seen .add (failure .failure )) {
153+ computeListener .acquireAvoid ().onFailure (failure .failure );
154+ }
155+ it .remove ();
156+ }
157+ }
158+
139159 private void sendOneNodeRequest (TargetShards targetShards , ComputeListener computeListener , NodeRequest request ) {
140160 final ActionListener <List <DriverProfile >> listener = computeListener .acquireCompute ();
141161 sendRequest (request .node , request .shardIds , request .aliasFilters , new NodeListener () {
@@ -148,7 +168,7 @@ void onAfter(List<DriverProfile> profiles) {
148168 @ Override
149169 public void onResponse (DataNodeComputeResponse response ) {
150170 // remove failures of successful shards
151- for (ShardId shardId : targetShards .shardIds ()) {
171+ for (ShardId shardId : request .shardIds ()) {
152172 if (response .shardLevelFailures ().containsKey (shardId ) == false ) {
153173 shardFailures .remove (shardId );
154174 }
@@ -250,6 +270,7 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
250270 final Iterator <ShardId > shardsIt = pendingShardIds .iterator ();
251271 while (shardsIt .hasNext ()) {
252272 ShardId shardId = shardsIt .next ();
273+ assert shardFailures .get (shardId ) == null || shardFailures .get (shardId ).fatal == false ;
253274 TargetShard shard = targetShards .getShard (shardId );
254275 Iterator <DiscoveryNode > nodesIt = shard .remainingNodes .iterator ();
255276 DiscoveryNode selectedNode = null ;
0 commit comments