2121import org .elasticsearch .action .support .PlainActionFuture ;
2222import org .elasticsearch .client .internal .node .NodeClient ;
2323import org .elasticsearch .common .settings .Settings ;
24+ import org .elasticsearch .common .util .set .Sets ;
2425import org .elasticsearch .http .HttpChannel ;
2526import org .elasticsearch .http .HttpResponse ;
2627import org .elasticsearch .tasks .Task ;
4445import java .util .concurrent .atomic .AtomicInteger ;
4546import java .util .concurrent .atomic .AtomicLong ;
4647import java .util .concurrent .atomic .AtomicReference ;
48+ import java .util .function .LongSupplier ;
4749
4850public class RestCancellableNodeClientTests extends ESTestCase {
4951
@@ -148,8 +150,42 @@ public void testChannelAlreadyClosed() {
148150 assertEquals (totalSearches , testClient .cancelledTasks .size ());
149151 }
150152
153+ public void testConcurrentExecuteAndClose () throws Exception {
154+ final var testClient = new TestClient (Settings .EMPTY , threadPool , true );
155+ int initialHttpChannels = RestCancellableNodeClient .getNumChannels ();
156+ int numTasks = randomIntBetween (1 , 30 );
157+ TestHttpChannel channel = new TestHttpChannel ();
158+ final var startLatch = new CountDownLatch (1 );
159+ final var doneLatch = new CountDownLatch (numTasks + 1 );
160+ final var expectedTasks = Sets .<TaskId >newHashSetWithExpectedSize (numTasks );
161+ for (int j = 0 ; j < numTasks ; j ++) {
162+ RestCancellableNodeClient client = new RestCancellableNodeClient (testClient , channel );
163+ threadPool .generic ().execute (() -> {
164+ client .execute (TransportSearchAction .TYPE , new SearchRequest (), ActionListener .running (ESTestCase ::fail ));
165+ startLatch .countDown ();
166+ doneLatch .countDown ();
167+ });
168+ expectedTasks .add (new TaskId (testClient .getLocalNodeId (), j ));
169+ }
170+ threadPool .generic ().execute (() -> {
171+ try {
172+ safeAwait (startLatch );
173+ channel .awaitClose ();
174+ } catch (InterruptedException e ) {
175+ Thread .currentThread ().interrupt ();
176+ throw new AssertionError (e );
177+ } finally {
178+ doneLatch .countDown ();
179+ }
180+ });
181+ safeAwait (doneLatch );
182+ assertEquals (initialHttpChannels , RestCancellableNodeClient .getNumChannels ());
183+ assertEquals (expectedTasks , testClient .cancelledTasks );
184+ }
185+
151186 private static class TestClient extends NodeClient {
152- private final AtomicLong counter = new AtomicLong (0 );
187+ private final LongSupplier searchTaskIdGenerator = new AtomicLong (0 )::getAndIncrement ;
188+ private final LongSupplier cancelTaskIdGenerator = new AtomicLong (1000 )::getAndIncrement ;
153189 private final Set <TaskId > cancelledTasks = new CopyOnWriteArraySet <>();
154190 private final AtomicInteger searchRequests = new AtomicInteger (0 );
155191 private final boolean timeout ;
@@ -167,9 +203,17 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
167203 ) {
168204 switch (action .name ()) {
169205 case TransportCancelTasksAction .NAME -> {
170- CancelTasksRequest cancelTasksRequest = (CancelTasksRequest ) request ;
171- assertTrue ("tried to cancel the same task more than once" , cancelledTasks .add (cancelTasksRequest .getTargetTaskId ()));
172- Task task = request .createTask (counter .getAndIncrement (), "cancel_task" , action .name (), null , Collections .emptyMap ());
206+ assertTrue (
207+ "tried to cancel the same task more than once" ,
208+ cancelledTasks .add (asInstanceOf (CancelTasksRequest .class , request ).getTargetTaskId ())
209+ );
210+ Task task = request .createTask (
211+ cancelTaskIdGenerator .getAsLong (),
212+ "cancel_task" ,
213+ action .name (),
214+ null ,
215+ Collections .emptyMap ()
216+ );
173217 if (randomBoolean ()) {
174218 listener .onResponse (null );
175219 } else {
@@ -180,7 +224,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
180224 }
181225 case TransportSearchAction .NAME -> {
182226 searchRequests .incrementAndGet ();
183- Task searchTask = request .createTask (counter .getAndIncrement (), "search" , action .name (), null , Collections .emptyMap ());
227+ Task searchTask = request .createTask (
228+ searchTaskIdGenerator .getAsLong (),
229+ "search" ,
230+ action .name (),
231+ null ,
232+ Collections .emptyMap ()
233+ );
184234 if (timeout == false ) {
185235 if (rarely ()) {
186236 // make sure that search is sometimes also called from the same thread before the task is returned
@@ -191,7 +241,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
191241 }
192242 return searchTask ;
193243 }
194- default -> throw new UnsupportedOperationException ( );
244+ default -> throw new AssertionError ( "unexpected action " + action . name () );
195245 }
196246
197247 }
@@ -222,10 +272,7 @@ public InetSocketAddress getRemoteAddress() {
222272
223273 @ Override
224274 public void close () {
225- if (open .compareAndSet (true , false ) == false ) {
226- assert false : "HttpChannel is already closed" ;
227- return ; // nothing to do
228- }
275+ assertTrue ("HttpChannel is already closed" , open .compareAndSet (true , false ));
229276 ActionListener <Void > listener = closeListener .get ();
230277 if (listener != null ) {
231278 boolean failure = randomBoolean ();
@@ -241,6 +288,7 @@ public void close() {
241288 }
242289
243290 private void awaitClose () throws InterruptedException {
291+ assertNotNull ("must set closeListener before calling awaitClose" , closeListener .get ());
244292 close ();
245293 closeLatch .await ();
246294 }
@@ -257,7 +305,7 @@ public void addCloseListener(ActionListener<Void> listener) {
257305 listener .onResponse (null );
258306 } else {
259307 if (closeListener .compareAndSet (null , listener ) == false ) {
260- throw new IllegalStateException ("close listener already set, only one is allowed!" );
308+ throw new AssertionError ("close listener already set, only one is allowed!" );
261309 }
262310 }
263311 }
0 commit comments