3333import org .elasticsearch .transport .TransportException ;
3434import org .elasticsearch .transport .TransportRequestOptions ;
3535import org .elasticsearch .transport .TransportService ;
36+ import org .elasticsearch .xpack .esql .action .EsqlResolveNodesAction ;
37+ import org .elasticsearch .xpack .esql .action .EsqlResolveNodesAction .ResolveNodesRequest ;
38+ import org .elasticsearch .xpack .esql .action .EsqlResolveNodesAction .ResolveNodesResponse ;
3639import org .elasticsearch .xpack .esql .action .EsqlSearchShardsAction ;
3740
3841import java .util .ArrayList ;
5255import java .util .concurrent .atomic .AtomicBoolean ;
5356import java .util .concurrent .atomic .AtomicInteger ;
5457import java .util .concurrent .locks .ReentrantLock ;
55- import java .util .function .Predicate ;
5658
59+ import static java .util .HashMap .newHashMap ;
5760import static org .elasticsearch .core .TimeValue .timeValueNanos ;
5861
5962/**
@@ -118,28 +121,22 @@ abstract class DataNodeRequestSender {
118121
119122 final void startComputeOnDataNodes (Set <String > concreteIndices , Runnable runOnTaskFailure , ActionListener <ComputeResponse > listener ) {
120123 final long startTimeInNanos = System .nanoTime ();
121- searchShards (
122- originalIndices .indices (),
123- shardId -> concreteIndices .contains (shardId .getIndexName ()),
124- ActionListener .wrap (targetShards -> {
125- try (
126- var computeListener = new ComputeListener (transportService .getThreadPool (), runOnTaskFailure , listener .map (profiles -> {
127- return new ComputeResponse (
128- profiles ,
129- timeValueNanos (System .nanoTime () - startTimeInNanos ),
130- targetShards .totalShards (),
131- targetShards .totalShards () - shardFailures .size () - skippedShards .get (),
132- targetShards .skippedShards () + skippedShards .get (),
133- shardFailures .size (),
134- selectFailures ()
135- );
136- }))
137- ) {
138- pendingShardIds .addAll (order (targetShards ));
139- trySendingRequestsForPendingShards (targetShards , computeListener );
140- }
141- }, listener ::onFailure )
142- );
124+ searchShards (concreteIndices , ActionListener .wrap (targetShards -> {
125+ try (var computeListener = new ComputeListener (transportService .getThreadPool (), runOnTaskFailure , listener .map (profiles -> {
126+ return new ComputeResponse (
127+ profiles ,
128+ timeValueNanos (System .nanoTime () - startTimeInNanos ),
129+ targetShards .totalShards (),
130+ targetShards .totalShards () - shardFailures .size () - skippedShards .get (),
131+ targetShards .skippedShards () + skippedShards .get (),
132+ shardFailures .size (),
133+ selectFailures ()
134+ );
135+ }))) {
136+ pendingShardIds .addAll (order (targetShards ));
137+ trySendingRequestsForPendingShards (targetShards , computeListener );
138+ }
139+ }, listener ::onFailure ));
143140 }
144141
145142 private static List <ShardId > order (TargetShards targetShards ) {
@@ -256,12 +253,11 @@ void onAfter(List<DriverProfile> profiles) {
256253 concurrentRequests .release ();
257254 }
258255
259- if (pendingRetries .isEmpty () == false && remainingTargetShardSearchAttempts .decrementAndGet () > 0 ) {
256+ if (pendingRetries .isEmpty () == false && remainingTargetShardSearchAttempts .getAndDecrement () > 0 ) {
260257 ongoingTargetShardResolutionAttempts .incrementAndGet ();
261- var indices = pendingRetries .stream ().map (ShardId ::getIndexName ).distinct ().toArray (String []::new );
262- searchShards (indices , pendingRetries ::contains , computeListener .acquireAvoid ().delegateFailure ((l , newSearchShards ) -> {
263- for (var entry : newSearchShards .shards .entrySet ()) {
264- targetShards .shards .get (entry .getKey ()).remainingNodes .addAll (entry .getValue ().remainingNodes );
258+ resolveShards (pendingRetries , computeListener .acquireAvoid ().delegateFailure ((l , resolutions ) -> {
259+ for (var entry : resolutions .entrySet ()) {
260+ targetShards .shards .get (entry .getKey ()).remainingNodes .addAll (entry .getValue ());
265261 }
266262 ongoingTargetShardResolutionAttempts .decrementAndGet ();
267263 trySendingRequestsForPendingShards (targetShards , computeListener );
@@ -356,7 +352,7 @@ private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception or
356352 }
357353
358354 /**
359- * Result from {@link #searchShards(String[], Predicate , ActionListener)} where can_match is performed to
355+ * Result from {@link #searchShards(Set , ActionListener)} where can_match is performed to
360356 * determine what shards can be skipped and which target nodes are needed for running the ES|QL query
361357 *
362358 * @param shards List of target shards to perform the ES|QL query on
@@ -446,18 +442,18 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
446442 * Ideally, the search_shards API should be called before the field-caps API; however, this can lead
447443 * to a situation where the column structure (i.e., matched data types) differs depending on the query.
448444 */
449- void searchShards (String [] indices , Predicate < ShardId > predicate , ActionListener <TargetShards > listener ) {
445+ void searchShards (Set < String > concreteIndices , ActionListener <TargetShards > listener ) {
450446 ActionListener <SearchShardsResponse > searchShardsListener = listener .map (resp -> {
451- Map <String , DiscoveryNode > nodes = new HashMap <>( );
447+ Map <String , DiscoveryNode > nodes = newHashMap ( resp . getNodes (). size () );
452448 for (DiscoveryNode node : resp .getNodes ()) {
453449 nodes .put (node .getId (), node );
454450 }
455451 int totalShards = 0 ;
456452 int skippedShards = 0 ;
457- Map <ShardId , TargetShard > shards = new HashMap <>( );
453+ Map <ShardId , TargetShard > shards = newHashMap ( resp . getGroups (). size () );
458454 for (SearchShardsGroup group : resp .getGroups ()) {
459455 var shardId = group .shardId ();
460- if (predicate . test (shardId ) == false ) {
456+ if (concreteIndices . contains (shardId . getIndexName () ) == false ) {
461457 continue ;
462458 }
463459 totalShards ++;
@@ -475,7 +471,7 @@ void searchShards(String[] indices, Predicate<ShardId> predicate, ActionListener
475471 return new TargetShards (shards , totalShards , skippedShards );
476472 });
477473 var searchShardsRequest = new SearchShardsRequest (
478- indices ,
474+ originalIndices . indices () ,
479475 originalIndices .indicesOptions (),
480476 requestFilter ,
481477 null ,
@@ -492,4 +488,15 @@ void searchShards(String[] indices, Predicate<ShardId> predicate, ActionListener
492488 new ActionListenerResponseHandler <>(searchShardsListener , SearchShardsResponse ::new , esqlExecutor )
493489 );
494490 }
491+
492+ void resolveShards (Set <ShardId > shardIds , ActionListener <Map <ShardId , List <DiscoveryNode >>> listener ) {
493+ transportService .sendChildRequest (
494+ transportService .getLocalNode (),
495+ EsqlResolveNodesAction .TYPE .name (),
496+ new ResolveNodesRequest (shardIds ),
497+ rootTask ,
498+ TransportRequestOptions .EMPTY ,
499+ new ActionListenerResponseHandler <>(listener .map (ResolveNodesResponse ::nodes ), ResolveNodesResponse ::new , esqlExecutor )
500+ );
501+ }
495502}
0 commit comments