2323import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
2424import org .elasticsearch .common .util .concurrent .EsExecutors ;
2525import org .elasticsearch .compute .test .ComputeTestCase ;
26- import org .elasticsearch .core .TimeValue ;
2726import org .elasticsearch .index .Index ;
2827import org .elasticsearch .index .query .QueryBuilder ;
2928import org .elasticsearch .index .shard .ShardId ;
3029import org .elasticsearch .search .internal .AliasFilter ;
3130import org .elasticsearch .tasks .CancellableTask ;
32- import org .elasticsearch .tasks .Task ;
3331import org .elasticsearch .tasks .TaskId ;
3432import org .elasticsearch .test .transport .MockTransportService ;
3533import org .elasticsearch .threadpool .FixedExecutorBuilder ;
4139import java .io .IOException ;
4240import java .util .ArrayList ;
4341import java .util .Arrays ;
42+ import java .util .Collection ;
4443import java .util .Collections ;
4544import java .util .HashMap ;
4645import java .util .List ;
5958import static org .elasticsearch .cluster .node .DiscoveryNodeRole .DATA_FROZEN_NODE_ROLE ;
6059import static org .elasticsearch .cluster .node .DiscoveryNodeRole .DATA_HOT_NODE_ROLE ;
6160import static org .elasticsearch .cluster .node .DiscoveryNodeRole .DATA_WARM_NODE_ROLE ;
61+ import static org .elasticsearch .core .TimeValue .timeValueNanos ;
6262import static org .elasticsearch .xpack .esql .plugin .DataNodeRequestSender .NodeRequest ;
6363import static org .hamcrest .Matchers .anyOf ;
64+ import static org .hamcrest .Matchers .contains ;
65+ import static org .hamcrest .Matchers .containsInAnyOrder ;
6466import static org .hamcrest .Matchers .containsString ;
6567import static org .hamcrest .Matchers .empty ;
6668import static org .hamcrest .Matchers .equalTo ;
@@ -120,12 +122,12 @@ public void testOnePass() {
120122 );
121123 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
122124 var future = sendRequests (targetShards , randomBoolean (), -1 , (node , shardIds , aliasFilters , listener ) -> {
123- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
125+ sent .add (nodeRequest (node , shardIds ));
124126 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
125127 });
126128 safeGet (future );
127129 assertThat (sent .size (), equalTo (2 ));
128- assertThat (groupRequests ( sent , 2 ), equalTo ( Map . of (node1 , List . of ( shard1 , shard3 ), node2 , List . of ( shard2 , shard4 ) )));
130+ assertThat (sent , containsInAnyOrder ( nodeRequest (node1 , shard1 , shard3 ), nodeRequest ( node2 , shard2 , shard4 )));
129131 }
130132
131133 public void testMissingShards () {
@@ -163,7 +165,7 @@ public void testRetryThenSuccess() {
163165 );
164166 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
165167 var future = sendRequests (targetShards , randomBoolean (), -1 , (node , shardIds , aliasFilters , listener ) -> {
166- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
168+ sent .add (nodeRequest (node , shardIds ));
167169 Map <ShardId , Exception > failures = new HashMap <>();
168170 if (node .equals (node1 ) && shardIds .contains (shard5 )) {
169171 failures .put (shard5 , new IOException ("test" ));
@@ -179,10 +181,11 @@ public void testRetryThenSuccess() {
179181 throw new AssertionError (e );
180182 }
181183 assertThat (sent , hasSize (5 ));
182- var firstRound = groupRequests (sent , 3 );
183- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard5 ), node4 , List .of (shard2 ), node2 , List .of (shard3 , shard4 ))));
184- var secondRound = groupRequests (sent , 2 );
185- assertThat (secondRound , equalTo (Map .of (node2 , List .of (shard2 ), node3 , List .of (shard5 ))));
184+ assertThat (
185+ take (sent , 3 ),
186+ containsInAnyOrder (nodeRequest (node1 , shard1 , shard5 ), nodeRequest (node4 , shard2 ), nodeRequest (node2 , shard3 , shard4 ))
187+ );
188+ assertThat (take (sent , 2 ), containsInAnyOrder (nodeRequest (node2 , shard2 ), nodeRequest (node3 , shard5 )));
186189 }
187190
188191 public void testRetryButFail () {
@@ -195,7 +198,7 @@ public void testRetryButFail() {
195198 );
196199 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
197200 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
198- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
201+ sent .add (nodeRequest (node , shardIds ));
199202 Map <ShardId , Exception > failures = new HashMap <>();
200203 if (shardIds .contains (shard5 )) {
201204 failures .put (shard5 , new IOException ("test failure for shard5" ));
@@ -206,22 +209,20 @@ public void testRetryButFail() {
206209 assertNotNull (ExceptionsHelper .unwrap (error , IOException .class ));
207210 // {node-1, node-2, node-4}, {node-3}, {node-2}
208211 assertThat (sent .size (), equalTo (5 ));
209- var firstRound = groupRequests (sent , 3 );
210- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard5 ), node2 , List .of (shard3 , shard4 ), node4 , List .of (shard2 ))));
211- NodeRequest fourth = sent .remove ();
212- assertThat (fourth .node (), equalTo (node3 ));
213- assertThat (fourth .shardIds (), equalTo (List .of (shard5 )));
214- NodeRequest fifth = sent .remove ();
215- assertThat (fifth .node (), equalTo (node2 ));
216- assertThat (fifth .shardIds (), equalTo (List .of (shard5 )));
212+ assertThat (
213+ take (sent , 3 ),
214+ containsInAnyOrder (nodeRequest (node1 , shard1 , shard5 ), nodeRequest (node2 , shard3 , shard4 ), nodeRequest (node4 , shard2 ))
215+ );
216+ assertThat (take (sent , 1 ), containsInAnyOrder (nodeRequest (node3 , shard5 )));
217+ assertThat (take (sent , 1 ), containsInAnyOrder (nodeRequest (node2 , shard5 )));
217218 }
218219
219220 public void testDoNotRetryOnRequestLevelFailure () {
220221 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard2 , node2 ), targetShard (shard3 , node1 ));
221222 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
222223 AtomicBoolean failed = new AtomicBoolean ();
223224 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
224- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
225+ sent .add (nodeRequest (node , shardIds ));
225226 if (node1 .equals (node ) && failed .compareAndSet (false , true )) {
226227 runWithDelay (() -> listener .onFailure (new IOException ("test request level failure" ), true ));
227228 } else {
@@ -232,37 +233,35 @@ public void testDoNotRetryOnRequestLevelFailure() {
232233 assertNotNull (ExceptionsHelper .unwrap (exception , IOException .class ));
233234 // one round: {node-1, node-2}
234235 assertThat (sent .size (), equalTo (2 ));
235- var firstRound = groupRequests (sent , 2 );
236- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard3 ), node2 , List .of (shard2 ))));
236+ assertThat (sent , containsInAnyOrder (nodeRequest (node1 , shard1 , shard3 ), nodeRequest (node2 , shard2 )));
237237 }
238238
239239 public void testAllowPartialResults () {
240240 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard2 , node2 ), targetShard (shard3 , node1 , node2 ));
241241 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
242242 AtomicBoolean failed = new AtomicBoolean ();
243243 var future = sendRequests (targetShards , true , -1 , (node , shardIds , aliasFilters , listener ) -> {
244- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
244+ sent .add (nodeRequest (node , shardIds ));
245245 if (node1 .equals (node ) && failed .compareAndSet (false , true )) {
246246 runWithDelay (() -> listener .onFailure (new IOException ("test request level failure" ), true ));
247247 } else {
248248 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
249249 }
250250 });
251- ComputeResponse resp = safeGet (future );
251+ var response = safeGet (future );
252+ assertThat (response .totalShards , equalTo (3 ));
253+ assertThat (response .failedShards , equalTo (2 ));
254+ assertThat (response .successfulShards , equalTo (1 ));
252255 // one round: {node-1, node-2}
253256 assertThat (sent .size (), equalTo (2 ));
254- var firstRound = groupRequests (sent , 2 );
255- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard3 ), node2 , List .of (shard2 ))));
256- assertThat (resp .totalShards , equalTo (3 ));
257- assertThat (resp .failedShards , equalTo (2 ));
258- assertThat (resp .successfulShards , equalTo (1 ));
257+ assertThat (sent , containsInAnyOrder (nodeRequest (node1 , shard1 , shard3 ), nodeRequest (node2 , shard2 )));
259258 }
260259
261260 public void testNonFatalErrorIsRetriedOnAnotherShard () {
262261 var targetShards = List .of (targetShard (shard1 , node1 , node2 ));
263262 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
264263 var response = safeGet (sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
265- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
264+ sent .add (nodeRequest (node , shardIds ));
266265 if (Objects .equals (node1 , node )) {
267266 runWithDelay (() -> listener .onFailure (new RuntimeException ("test request level non fatal failure" ), false ));
268267 } else {
@@ -279,7 +278,7 @@ public void testNonFatalFailedOnAllNodes() {
279278 var targetShards = List .of (targetShard (shard1 , node1 , node2 ));
280279 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
281280 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
282- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
281+ sent .add (nodeRequest (node , shardIds ));
283282 runWithDelay (() -> listener .onFailure (new RuntimeException ("test request level non fatal failure" ), false ));
284283 });
285284 expectThrows (RuntimeException .class , equalTo ("test request level non fatal failure" ), future ::actionGet );
@@ -290,7 +289,7 @@ public void testDoNotRetryCircuitBreakerException() {
290289 var targetShards = List .of (targetShard (shard1 , node1 , node2 ));
291290 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
292291 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
293- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
292+ sent .add (nodeRequest (node , shardIds ));
294293 runWithDelay (() -> listener .onFailure (new CircuitBreakingException ("cbe" , randomFrom (Durability .values ())), false ));
295294 });
296295 expectThrows (CircuitBreakingException .class , equalTo ("cbe" ), future ::actionGet );
@@ -321,7 +320,7 @@ public void testLimitConcurrentNodes() {
321320 }
322321 }
323322
324- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
323+ sent .add (nodeRequest (node , shardIds ));
325324 runWithDelay (() -> {
326325 concurrentRequests .decrementAndGet ();
327326 listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ()));
@@ -364,7 +363,7 @@ public void testSkipRemovesPriorNonFatalErrors() {
364363
365364 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
366365 var response = safeGet (sendRequests (targetShards , randomBoolean (), 1 , (node , shardIds , aliasFilters , listener ) -> {
367- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
366+ sent .add (nodeRequest (node , shardIds ));
368367 runWithDelay (() -> {
369368 if (Objects .equals (node .getId (), node1 .getId ()) && shardIds .equals (List .of (shard1 ))) {
370369 listener .onFailure (new RuntimeException ("test request level non fatal failure" ), false );
@@ -406,29 +405,38 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() {
406405 );
407406 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
408407 safeGet (sendRequests (targetShards , randomBoolean (), -1 , (node , shardIds , aliasFilters , listener ) -> {
409- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
408+ sent .add (nodeRequest (node , shardIds ));
410409 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
411410 }));
412- assertThat (groupRequests (sent , 1 ), equalTo ( Map . of (node1 , List . of ( shard1 ) )));
413- assertThat (groupRequests (sent , 1 ), anyOf (equalTo ( Map . of (node2 , List . of ( shard2 ))), equalTo ( Map . of (warmNode2 , List . of ( shard2 ) ))));
411+ assertThat (take (sent , 1 ), containsInAnyOrder ( nodeRequest (node1 , shard1 )));
412+ assertThat (take (sent , 1 ), anyOf (contains ( nodeRequest (node2 , shard2 )), contains ( nodeRequest (warmNode2 , shard2 ))));
414413 }
415414
416415 static DataNodeRequestSender .TargetShard targetShard (ShardId shardId , DiscoveryNode ... nodes ) {
417416 return new DataNodeRequestSender .TargetShard (shardId , new ArrayList <>(Arrays .asList (nodes )), null );
418417 }
419418
420- static Map <DiscoveryNode , List <ShardId >> groupRequests (Queue <NodeRequest > sent , int limit ) {
421- Map <DiscoveryNode , List <ShardId >> map = new HashMap <>();
419+ static DataNodeRequestSender .NodeRequest nodeRequest (DiscoveryNode node , ShardId ... shardIds ) {
420+ return nodeRequest (node , Arrays .asList (shardIds ));
421+ }
422+
423+ static DataNodeRequestSender .NodeRequest nodeRequest (DiscoveryNode node , List <ShardId > shardIds ) {
424+ var copy = new ArrayList <>(shardIds );
425+ Collections .sort (copy );
426+ return new NodeRequest (node , copy , Map .of ());
427+ }
428+
429+ static <T > Collection <T > take (Queue <T > queue , int limit ) {
430+ var result = new ArrayList <T >(limit );
422431 for (int i = 0 ; i < limit ; i ++) {
423- NodeRequest r = sent .remove ();
424- assertNull (map .put (r .node (), r .shardIds ().stream ().sorted ().toList ()));
432+ result .add (queue .remove ());
425433 }
426- return map ;
434+ return result ;
427435 }
428436
429437 void runWithDelay (Runnable runnable ) {
430438 if (randomBoolean ()) {
431- threadPool .schedule (runnable , TimeValue . timeValueNanos (between (0 , 5000 )), executor );
439+ threadPool .schedule (runnable , timeValueNanos (between (0 , 5000 )), executor );
432440 } else {
433441 executor .execute (runnable );
434442 }
@@ -465,8 +473,6 @@ PlainActionFuture<ComputeResponse> sendRequests(
465473 ) {
466474 @ Override
467475 void searchShards (
468- Task parentTask ,
469- String clusterAlias ,
470476 QueryBuilder filter ,
471477 Set <String > concreteIndices ,
472478 OriginalIndices originalIndices ,
@@ -477,7 +483,6 @@ void searchShards(
477483 shards .size (),
478484 0
479485 );
480- assertSame (parentTask , task );
481486 runWithDelay (() -> listener .onResponse (targetShards ));
482487 }
483488
@@ -492,7 +497,6 @@ protected void sendRequest(
492497 }
493498 };
494499 requestSender .startComputeOnDataNodes (
495- "" ,
496500 Set .of (randomAlphaOfLength (10 )),
497501 new OriginalIndices (new String [0 ], SearchRequest .DEFAULT_INDICES_OPTIONS ),
498502 null ,
0 commit comments