2323import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
2424import org .elasticsearch .common .util .concurrent .EsExecutors ;
2525import org .elasticsearch .compute .test .ComputeTestCase ;
26- import org .elasticsearch .core .TimeValue ;
2726import org .elasticsearch .index .Index ;
2827import org .elasticsearch .index .query .QueryBuilder ;
2928import org .elasticsearch .index .shard .ShardId ;
4140import java .io .IOException ;
4241import java .util .ArrayList ;
4342import java .util .Arrays ;
43+ import java .util .Collection ;
4444import java .util .Collections ;
4545import java .util .HashMap ;
4646import java .util .List ;
5959import static org .elasticsearch .cluster .node .DiscoveryNodeRole .DATA_FROZEN_NODE_ROLE ;
6060import static org .elasticsearch .cluster .node .DiscoveryNodeRole .DATA_HOT_NODE_ROLE ;
6161import static org .elasticsearch .cluster .node .DiscoveryNodeRole .DATA_WARM_NODE_ROLE ;
62+ import static org .elasticsearch .core .TimeValue .timeValueNanos ;
6263import static org .elasticsearch .xpack .esql .plugin .DataNodeRequestSender .NodeRequest ;
6364import static org .hamcrest .Matchers .anyOf ;
65+ import static org .hamcrest .Matchers .contains ;
66+ import static org .hamcrest .Matchers .containsInAnyOrder ;
6467import static org .hamcrest .Matchers .containsString ;
6568import static org .hamcrest .Matchers .empty ;
6669import static org .hamcrest .Matchers .equalTo ;
@@ -120,12 +123,12 @@ public void testOnePass() {
120123 );
121124 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
122125 var future = sendRequests (targetShards , randomBoolean (), -1 , (node , shardIds , aliasFilters , listener ) -> {
123- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
126+ sent .add (nodeRequest (node , shardIds ));
124127 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
125128 });
126129 safeGet (future );
127130 assertThat (sent .size (), equalTo (2 ));
128- assertThat (groupRequests ( sent , 2 ), equalTo ( Map . of (node1 , List . of ( shard1 , shard3 ), node2 , List . of ( shard2 , shard4 ) )));
131+ assertThat (sent , containsInAnyOrder ( nodeRequest (node1 , shard1 , shard3 ), nodeRequest ( node2 , shard2 , shard4 )));
129132 }
130133
131134 public void testMissingShards () {
@@ -163,7 +166,7 @@ public void testRetryThenSuccess() {
163166 );
164167 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
165168 var future = sendRequests (targetShards , randomBoolean (), -1 , (node , shardIds , aliasFilters , listener ) -> {
166- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
169+ sent .add (nodeRequest (node , shardIds ));
167170 Map <ShardId , Exception > failures = new HashMap <>();
168171 if (node .equals (node1 ) && shardIds .contains (shard5 )) {
169172 failures .put (shard5 , new IOException ("test" ));
@@ -179,10 +182,11 @@ public void testRetryThenSuccess() {
179182 throw new AssertionError (e );
180183 }
181184 assertThat (sent , hasSize (5 ));
182- var firstRound = groupRequests (sent , 3 );
183- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard5 ), node4 , List .of (shard2 ), node2 , List .of (shard3 , shard4 ))));
184- var secondRound = groupRequests (sent , 2 );
185- assertThat (secondRound , equalTo (Map .of (node2 , List .of (shard2 ), node3 , List .of (shard5 ))));
185+ assertThat (
186+ take (sent , 3 ),
187+ containsInAnyOrder (nodeRequest (node1 , shard1 , shard5 ), nodeRequest (node4 , shard2 ), nodeRequest (node2 , shard3 , shard4 ))
188+ );
189+ assertThat (take (sent , 2 ), containsInAnyOrder (nodeRequest (node2 , shard2 ), nodeRequest (node3 , shard5 )));
186190 }
187191
188192 public void testRetryButFail () {
@@ -195,7 +199,7 @@ public void testRetryButFail() {
195199 );
196200 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
197201 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
198- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
202+ sent .add (nodeRequest (node , shardIds ));
199203 Map <ShardId , Exception > failures = new HashMap <>();
200204 if (shardIds .contains (shard5 )) {
201205 failures .put (shard5 , new IOException ("test failure for shard5" ));
@@ -206,22 +210,20 @@ public void testRetryButFail() {
206210 assertNotNull (ExceptionsHelper .unwrap (error , IOException .class ));
207211 // {node-1, node-2, node-4}, {node-3}, {node-2}
208212 assertThat (sent .size (), equalTo (5 ));
209- var firstRound = groupRequests (sent , 3 );
210- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard5 ), node2 , List .of (shard3 , shard4 ), node4 , List .of (shard2 ))));
211- NodeRequest fourth = sent .remove ();
212- assertThat (fourth .node (), equalTo (node3 ));
213- assertThat (fourth .shardIds (), equalTo (List .of (shard5 )));
214- NodeRequest fifth = sent .remove ();
215- assertThat (fifth .node (), equalTo (node2 ));
216- assertThat (fifth .shardIds (), equalTo (List .of (shard5 )));
213+ assertThat (
214+ take (sent , 3 ),
215+ containsInAnyOrder (nodeRequest (node1 , shard1 , shard5 ), nodeRequest (node2 , shard3 , shard4 ), nodeRequest (node4 , shard2 ))
216+ );
217+ assertThat (take (sent , 1 ), containsInAnyOrder (nodeRequest (node3 , shard5 )));
218+ assertThat (take (sent , 1 ), containsInAnyOrder (nodeRequest (node2 , shard5 )));
217219 }
218220
219221 public void testDoNotRetryOnRequestLevelFailure () {
220222 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard2 , node2 ), targetShard (shard3 , node1 ));
221223 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
222224 AtomicBoolean failed = new AtomicBoolean ();
223225 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
224- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
226+ sent .add (nodeRequest (node , shardIds ));
225227 if (node1 .equals (node ) && failed .compareAndSet (false , true )) {
226228 runWithDelay (() -> listener .onFailure (new IOException ("test request level failure" ), true ));
227229 } else {
@@ -232,37 +234,35 @@ public void testDoNotRetryOnRequestLevelFailure() {
232234 assertNotNull (ExceptionsHelper .unwrap (exception , IOException .class ));
233235 // one round: {node-1, node-2}
234236 assertThat (sent .size (), equalTo (2 ));
235- var firstRound = groupRequests (sent , 2 );
236- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard3 ), node2 , List .of (shard2 ))));
237+ assertThat (sent , containsInAnyOrder (nodeRequest (node1 , shard1 , shard3 ), nodeRequest (node2 , shard2 )));
237238 }
238239
239240 public void testAllowPartialResults () {
240241 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard2 , node2 ), targetShard (shard3 , node1 , node2 ));
241242 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
242243 AtomicBoolean failed = new AtomicBoolean ();
243244 var future = sendRequests (targetShards , true , -1 , (node , shardIds , aliasFilters , listener ) -> {
244- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
245+ sent .add (nodeRequest (node , shardIds ));
245246 if (node1 .equals (node ) && failed .compareAndSet (false , true )) {
246247 runWithDelay (() -> listener .onFailure (new IOException ("test request level failure" ), true ));
247248 } else {
248249 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
249250 }
250251 });
251- ComputeResponse resp = safeGet (future );
252+ var response = safeGet (future );
253+ assertThat (response .totalShards , equalTo (3 ));
254+ assertThat (response .failedShards , equalTo (2 ));
255+ assertThat (response .successfulShards , equalTo (1 ));
252256 // one round: {node-1, node-2}
253257 assertThat (sent .size (), equalTo (2 ));
254- var firstRound = groupRequests (sent , 2 );
255- assertThat (firstRound , equalTo (Map .of (node1 , List .of (shard1 , shard3 ), node2 , List .of (shard2 ))));
256- assertThat (resp .totalShards , equalTo (3 ));
257- assertThat (resp .failedShards , equalTo (2 ));
258- assertThat (resp .successfulShards , equalTo (1 ));
258+ assertThat (sent , containsInAnyOrder (nodeRequest (node1 , shard1 , shard3 ), nodeRequest (node2 , shard2 )));
259259 }
260260
261261 public void testNonFatalErrorIsRetriedOnAnotherShard () {
262262 var targetShards = List .of (targetShard (shard1 , node1 , node2 ));
263263 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
264264 var response = safeGet (sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
265- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
265+ sent .add (nodeRequest (node , shardIds ));
266266 if (Objects .equals (node1 , node )) {
267267 runWithDelay (() -> listener .onFailure (new RuntimeException ("test request level non fatal failure" ), false ));
268268 } else {
@@ -279,7 +279,7 @@ public void testNonFatalFailedOnAllNodes() {
279279 var targetShards = List .of (targetShard (shard1 , node1 , node2 ));
280280 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
281281 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
282- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
282+ sent .add (nodeRequest (node , shardIds ));
283283 runWithDelay (() -> listener .onFailure (new RuntimeException ("test request level non fatal failure" ), false ));
284284 });
285285 expectThrows (RuntimeException .class , equalTo ("test request level non fatal failure" ), future ::actionGet );
@@ -290,7 +290,7 @@ public void testDoNotRetryCircuitBreakerException() {
290290 var targetShards = List .of (targetShard (shard1 , node1 , node2 ));
291291 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
292292 var future = sendRequests (targetShards , false , -1 , (node , shardIds , aliasFilters , listener ) -> {
293- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
293+ sent .add (nodeRequest (node , shardIds ));
294294 runWithDelay (() -> listener .onFailure (new CircuitBreakingException ("cbe" , randomFrom (Durability .values ())), false ));
295295 });
296296 expectThrows (CircuitBreakingException .class , equalTo ("cbe" ), future ::actionGet );
@@ -321,7 +321,7 @@ public void testLimitConcurrentNodes() {
321321 }
322322 }
323323
324- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
324+ sent .add (nodeRequest (node , shardIds ));
325325 runWithDelay (() -> {
326326 concurrentRequests .decrementAndGet ();
327327 listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ()));
@@ -364,7 +364,7 @@ public void testSkipRemovesPriorNonFatalErrors() {
364364
365365 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
366366 var response = safeGet (sendRequests (targetShards , randomBoolean (), 1 , (node , shardIds , aliasFilters , listener ) -> {
367- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
367+ sent .add (nodeRequest (node , shardIds ));
368368 runWithDelay (() -> {
369369 if (Objects .equals (node .getId (), node1 .getId ()) && shardIds .equals (List .of (shard1 ))) {
370370 listener .onFailure (new RuntimeException ("test request level non fatal failure" ), false );
@@ -406,29 +406,38 @@ public void testQueryHotShardsFirstWhenIlmMovesShard() {
406406 );
407407 var sent = ConcurrentCollections .<NodeRequest >newQueue ();
408408 safeGet (sendRequests (targetShards , randomBoolean (), -1 , (node , shardIds , aliasFilters , listener ) -> {
409- sent .add (new NodeRequest (node , shardIds , aliasFilters ));
409+ sent .add (nodeRequest (node , shardIds ));
410410 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
411411 }));
412- assertThat (groupRequests (sent , 1 ), equalTo ( Map . of (node1 , List . of ( shard1 ) )));
413- assertThat (groupRequests (sent , 1 ), anyOf (equalTo ( Map . of (node2 , List . of ( shard2 ))), equalTo ( Map . of (warmNode2 , List . of ( shard2 ) ))));
412+ assertThat (take (sent , 1 ), containsInAnyOrder ( nodeRequest (node1 , shard1 )));
413+ assertThat (take (sent , 1 ), anyOf (contains ( nodeRequest (node2 , shard2 )), contains ( nodeRequest (warmNode2 , shard2 ))));
414414 }
415415
416416 static DataNodeRequestSender .TargetShard targetShard (ShardId shardId , DiscoveryNode ... nodes ) {
417417 return new DataNodeRequestSender .TargetShard (shardId , new ArrayList <>(Arrays .asList (nodes )), null );
418418 }
419419
420- static Map <DiscoveryNode , List <ShardId >> groupRequests (Queue <NodeRequest > sent , int limit ) {
421- Map <DiscoveryNode , List <ShardId >> map = new HashMap <>();
420+ static DataNodeRequestSender .NodeRequest nodeRequest (DiscoveryNode node , ShardId ... shardIds ) {
421+ return nodeRequest (node , Arrays .asList (shardIds ));
422+ }
423+
424+ static DataNodeRequestSender .NodeRequest nodeRequest (DiscoveryNode node , List <ShardId > shardIds ) {
425+ var copy = new ArrayList <>(shardIds );
426+ Collections .sort (copy );
427+ return new NodeRequest (node , copy , Map .of ());
428+ }
429+
430+ static <T > Collection <T > take (Queue <T > queue , int limit ) {
431+ var result = new ArrayList <T >(limit );
422432 for (int i = 0 ; i < limit ; i ++) {
423- NodeRequest r = sent .remove ();
424- assertNull (map .put (r .node (), r .shardIds ().stream ().sorted ().toList ()));
433+ result .add (queue .remove ());
425434 }
426- return map ;
435+ return result ;
427436 }
428437
429438 void runWithDelay (Runnable runnable ) {
430439 if (randomBoolean ()) {
431- threadPool .schedule (runnable , TimeValue . timeValueNanos (between (0 , 5000 )), executor );
440+ threadPool .schedule (runnable , timeValueNanos (between (0 , 5000 )), executor );
432441 } else {
433442 executor .execute (runnable );
434443 }
0 commit comments