4545import java .util .concurrent .Executor ;
4646import java .util .concurrent .TimeUnit ;
4747import java .util .concurrent .atomic .AtomicBoolean ;
48+ import java .util .concurrent .atomic .AtomicInteger ;
4849import java .util .function .Function ;
4950import java .util .stream .Collectors ;
5051
@@ -85,14 +86,15 @@ public void setThreadPool() {
8586 }
8687
8788 @ After
88- public void shutdownThreadPool () throws Exception {
89+ public void shutdownThreadPool () {
8990 terminate (threadPool );
9091 }
9192
9293 public void testEmpty () {
9394 var future = sendRequests (
9495 List .of (),
9596 randomBoolean (),
97+ 10 ,
9698 (node , shardIds , aliasFilters , listener ) -> fail ("expect no data-node request is sent" )
9799 );
98100 var resp = safeGet (future );
@@ -107,10 +109,9 @@ public void testOnePass() {
107109 targetShard (shard4 , node2 , node3 )
108110 );
109111 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
110- var future = sendRequests (targetShards , randomBoolean (), (node , shardIds , aliasFilters , listener ) -> {
112+ var future = sendRequests (targetShards , randomBoolean (), 10 , (node , shardIds , aliasFilters , listener ) -> {
111113 sent .add (new NodeRequest (node , shardIds , aliasFilters ));
112- var resp = new DataNodeComputeResponse (List .of (), Map .of ());
113- runWithDelay (() -> listener .onResponse (resp ));
114+ runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
114115 });
115116 safeGet (future );
116117 assertThat (sent .size (), equalTo (2 ));
@@ -120,15 +121,15 @@ public void testOnePass() {
120121 public void testMissingShards () {
121122 {
122123 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard3 ), targetShard (shard4 , node2 , node3 ));
123- var future = sendRequests (targetShards , false , (node , shardIds , aliasFilters , listener ) -> {
124+ var future = sendRequests (targetShards , false , 10 , (node , shardIds , aliasFilters , listener ) -> {
124125 fail ("expect no data-node request is sent when target shards are missing" );
125126 });
126127 var error = expectThrows (NoShardAvailableActionException .class , future ::actionGet );
127128 assertThat (error .getMessage (), containsString ("no shard copies found" ));
128129 }
129130 {
130131 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard3 ), targetShard (shard4 , node2 , node3 ));
131- var future = sendRequests (targetShards , true , (node , shardIds , aliasFilters , listener ) -> {
132+ var future = sendRequests (targetShards , true , 10 , (node , shardIds , aliasFilters , listener ) -> {
132133 assertThat (shard3 , not (in (shardIds )));
133134 runWithDelay (() -> listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ())));
134135 });
@@ -148,7 +149,7 @@ public void testRetryThenSuccess() {
148149 targetShard (shard5 , node1 , node3 , node2 )
149150 );
150151 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
151- var future = sendRequests (targetShards , randomBoolean (), (node , shardIds , aliasFilters , listener ) -> {
152+ var future = sendRequests (targetShards , randomBoolean (), 10 , (node , shardIds , aliasFilters , listener ) -> {
152153 sent .add (new NodeRequest (node , shardIds , aliasFilters ));
153154 Map <ShardId , Exception > failures = new HashMap <>();
154155 if (node .equals (node1 ) && shardIds .contains (shard5 )) {
@@ -180,7 +181,7 @@ public void testRetryButFail() {
180181 targetShard (shard5 , node1 , node3 , node2 )
181182 );
182183 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
183- var future = sendRequests (targetShards , false , (node , shardIds , aliasFilters , listener ) -> {
184+ var future = sendRequests (targetShards , false , 10 , (node , shardIds , aliasFilters , listener ) -> {
184185 sent .add (new NodeRequest (node , shardIds , aliasFilters ));
185186 Map <ShardId , Exception > failures = new HashMap <>();
186187 if (shardIds .contains (shard5 )) {
@@ -206,7 +207,7 @@ public void testDoNotRetryOnRequestLevelFailure() {
206207 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard2 , node2 ), targetShard (shard3 , node1 ));
207208 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
208209 AtomicBoolean failed = new AtomicBoolean ();
209- var future = sendRequests (targetShards , false , (node , shardIds , aliasFilters , listener ) -> {
210+ var future = sendRequests (targetShards , false , 10 , (node , shardIds , aliasFilters , listener ) -> {
210211 sent .add (new NodeRequest (node , shardIds , aliasFilters ));
211212 if (node1 .equals (node ) && failed .compareAndSet (false , true )) {
212213 runWithDelay (() -> listener .onFailure (new IOException ("test request level failure" ), true ));
@@ -226,7 +227,7 @@ public void testAllowPartialResults() {
226227 var targetShards = List .of (targetShard (shard1 , node1 ), targetShard (shard2 , node2 ), targetShard (shard3 , node1 , node2 ));
227228 Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
228229 AtomicBoolean failed = new AtomicBoolean ();
229- var future = sendRequests (targetShards , true , (node , shardIds , aliasFilters , listener ) -> {
230+ var future = sendRequests (targetShards , true , 10 , (node , shardIds , aliasFilters , listener ) -> {
230231 sent .add (new NodeRequest (node , shardIds , aliasFilters ));
231232 if (node1 .equals (node ) && failed .compareAndSet (false , true )) {
232233 runWithDelay (() -> listener .onFailure (new IOException ("test request level failure" ), true ));
@@ -244,6 +245,40 @@ public void testAllowPartialResults() {
244245 assertThat (resp .successfulShards , equalTo (1 ));
245246 }
246247
248+ public void testLimitConcurrentNodes () {
249+ var targetShards = List .of (
250+ targetShard (shard1 , node1 ),
251+ targetShard (shard2 , node2 ),
252+ targetShard (shard3 , node3 ),
253+ targetShard (shard4 , node4 ),
254+ targetShard (shard5 , node5 )
255+ );
256+
257+ AtomicInteger maxConcurrentRequests = new AtomicInteger (0 );
258+ AtomicInteger concurrentRequests = new AtomicInteger (0 );
259+ Queue <NodeRequest > sent = ConcurrentCollections .newQueue ();
260+ var future = sendRequests (targetShards , randomBoolean (), 2 , (node , shardIds , aliasFilters , listener ) -> {
261+ concurrentRequests .incrementAndGet ();
262+
263+ while (true ) {
264+ var priorMax = maxConcurrentRequests .get ();
265+ var newMax = Math .max (priorMax , concurrentRequests .get ());
266+ if (newMax <= priorMax || maxConcurrentRequests .compareAndSet (priorMax , newMax )) {
267+ break ;
268+ }
269+ }
270+
271+ sent .add (new NodeRequest (node , shardIds , aliasFilters ));
272+ runWithDelay (() -> {
273+ concurrentRequests .decrementAndGet ();
274+ listener .onResponse (new DataNodeComputeResponse (List .of (), Map .of ()));
275+ });
276+ });
277+ safeGet (future );
278+ assertThat (sent .size (), equalTo (5 ));
279+ assertThat (maxConcurrentRequests .get (), equalTo (2 ));
280+ }
281+
247282 static DataNodeRequestSender .TargetShard targetShard (ShardId shardId , DiscoveryNode ... nodes ) {
248283 return new DataNodeRequestSender .TargetShard (shardId , new ArrayList <>(Arrays .asList (nodes )), null );
249284 }
@@ -268,6 +303,7 @@ void runWithDelay(Runnable runnable) {
268303 PlainActionFuture <ComputeResponse > sendRequests (
269304 List <DataNodeRequestSender .TargetShard > shards ,
270305 boolean allowPartialResults ,
306+ int concurrentRequests ,
271307 Sender sender
272308 ) {
273309 PlainActionFuture <ComputeResponse > future = new PlainActionFuture <>();
@@ -281,7 +317,13 @@ PlainActionFuture<ComputeResponse> sendRequests(
281317 TaskId .EMPTY_TASK_ID ,
282318 Collections .emptyMap ()
283319 );
284- DataNodeRequestSender requestSender = new DataNodeRequestSender (transportService , executor , task , allowPartialResults ) {
320+ DataNodeRequestSender requestSender = new DataNodeRequestSender (
321+ transportService ,
322+ executor ,
323+ task ,
324+ allowPartialResults ,
325+ concurrentRequests
326+ ) {
285327 @ Override
286328 void searchShards (
287329 Task parentTask ,
0 commit comments