| 
10 | 10 | package org.elasticsearch.action.support.nodes;  | 
11 | 11 | 
 
  | 
12 | 12 | import org.elasticsearch.ElasticsearchException;  | 
 | 13 | +import org.elasticsearch.ExceptionsHelper;  | 
13 | 14 | import org.elasticsearch.action.ActionListener;  | 
14 | 15 | import org.elasticsearch.action.FailedNodeException;  | 
15 | 16 | import org.elasticsearch.action.support.ActionFilters;  | 
 | 
57 | 58 | import java.util.List;  | 
58 | 59 | import java.util.Map;  | 
59 | 60 | import java.util.Set;  | 
 | 61 | +import java.util.concurrent.CountDownLatch;  | 
 | 62 | +import java.util.concurrent.CyclicBarrier;  | 
60 | 63 | import java.util.concurrent.Executor;  | 
61 | 64 | import java.util.concurrent.TimeUnit;  | 
62 | 65 | import java.util.concurrent.atomic.AtomicInteger;  | 
 | 
66 | 69 | import static java.util.Collections.emptyMap;  | 
67 | 70 | import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;  | 
68 | 71 | import static org.elasticsearch.test.ClusterServiceUtils.setState;  | 
 | 72 | +import static org.hamcrest.Matchers.containsString;  | 
69 | 73 | import static org.hamcrest.Matchers.greaterThan;  | 
 | 74 | +import static org.hamcrest.Matchers.hasSize;  | 
70 | 75 | import static org.mockito.Mockito.mock;  | 
71 | 76 | 
 
  | 
72 | 77 | public class TransportNodesActionTests extends ESTestCase {  | 
@@ -316,6 +321,137 @@ protected Object createActionContext(Task task, TestNodesRequest request) {  | 
316 | 321 |         assertTrue(cancellableTask.isCancelled()); // keep task alive  | 
317 | 322 |     }  | 
318 | 323 | 
 
  | 
 | 324 | +    public void testCompletionShouldNotBeInterferedByCancellationAfterProcessingBegins() throws Exception {  | 
 | 325 | +        final var barrier = new CyclicBarrier(2);  | 
 | 326 | +        final var action = new TestTransportNodesAction(  | 
 | 327 | +            clusterService,  | 
 | 328 | +            transportService,  | 
 | 329 | +            new ActionFilters(Set.of()),  | 
 | 330 | +            TestNodeRequest::new,  | 
 | 331 | +            THREAD_POOL.executor(ThreadPool.Names.GENERIC)  | 
 | 332 | +        ) {  | 
 | 333 | +            @Override  | 
 | 334 | +            protected void newResponseAsync(  | 
 | 335 | +                Task task,  | 
 | 336 | +                TestNodesRequest request,  | 
 | 337 | +                Void unused,  | 
 | 338 | +                List<TestNodeResponse> testNodeResponses,  | 
 | 339 | +                List<FailedNodeException> failures,  | 
 | 340 | +                ActionListener<TestNodesResponse> listener  | 
 | 341 | +            ) {  | 
 | 342 | +                boolean waited = false;  | 
 | 343 | +                // Process node responses in a loop and ensure no ConcurrentModificationException will be thrown due to  | 
 | 344 | +                // concurrent cancellation coming after the loop has started, see also #128852  | 
 | 345 | +                for (var response : testNodeResponses) {  | 
 | 346 | +                    if (waited == false) {  | 
 | 347 | +                        waited = true;  | 
 | 348 | +                        safeAwait(barrier);  | 
 | 349 | +                        safeAwait(barrier);  | 
 | 350 | +                    }  | 
 | 351 | +                }  | 
 | 352 | +                super.newResponseAsync(task, request, unused, testNodeResponses, failures, listener);  | 
 | 353 | +            }  | 
 | 354 | +        };  | 
 | 355 | + | 
 | 356 | +        final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());  | 
 | 357 | +        final var cancelledFuture = new PlainActionFuture<Void>();  | 
 | 358 | +        cancellableTask.addListener(() -> cancelledFuture.onResponse(null));  | 
 | 359 | + | 
 | 360 | +        final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();  | 
 | 361 | +        action.execute(cancellableTask, new TestNodesRequest(), future);  | 
 | 362 | + | 
 | 363 | +        for (var capturedRequest : transport.getCapturedRequestsAndClear()) {  | 
 | 364 | +            completeOneRequest(capturedRequest);  | 
 | 365 | +        }  | 
 | 366 | + | 
 | 367 | +        // Wait for the overall response to start processing the node responses in a loop and then cancel the task.  | 
 | 368 | +        // The cancellation should not interfere with the node response processing.  | 
 | 369 | +        safeAwait(barrier);  | 
 | 370 | +        TaskCancelHelper.cancel(cancellableTask, "simulated");  | 
 | 371 | +        safeGet(cancelledFuture);  | 
 | 372 | + | 
 | 373 | +        // Let the process continue, and it should be successful  | 
 | 374 | +        safeAwait(barrier);  | 
 | 375 | +        assertResponseReleased(safeGet(future));  | 
 | 376 | +    }  | 
 | 377 | + | 
 | 378 | +    public void testConcurrentlyCompletionAndCancellation() throws InterruptedException {  | 
 | 379 | +        final var action = getTestTransportNodesAction();  | 
 | 380 | + | 
 | 381 | +        final CountDownLatch onCancelledLatch = new CountDownLatch(1);  | 
 | 382 | +        final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) {  | 
 | 383 | +            @Override  | 
 | 384 | +            protected void onCancelled() {  | 
 | 385 | +                onCancelledLatch.countDown();  | 
 | 386 | +            }  | 
 | 387 | +        };  | 
 | 388 | + | 
 | 389 | +        final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>();  | 
 | 390 | +        action.execute(cancellableTask, new TestNodesRequest(), future);  | 
 | 391 | + | 
 | 392 | +        final List<TestNodeResponse> nodeResponses = new ArrayList<>();  | 
 | 393 | +        final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();  | 
 | 394 | +        for (int i = 0; i < capturedRequests.length - 1; i++) {  | 
 | 395 | +            final var capturedRequest = capturedRequests[i];  | 
 | 396 | +            nodeResponses.add(completeOneRequest(capturedRequest));  | 
 | 397 | +        }  | 
 | 398 | + | 
 | 399 | +        final var raceBarrier = new CyclicBarrier(3);  | 
 | 400 | +        final Thread completeThread = new Thread(() -> {  | 
 | 401 | +            safeAwait(raceBarrier);  | 
 | 402 | +            nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1]));  | 
 | 403 | +        });  | 
 | 404 | +        final Thread cancelThread = new Thread(() -> {  | 
 | 405 | +            safeAwait(raceBarrier);  | 
 | 406 | +            TaskCancelHelper.cancel(cancellableTask, "simulated");  | 
 | 407 | +        });  | 
 | 408 | +        completeThread.start();  | 
 | 409 | +        cancelThread.start();  | 
 | 410 | +        safeAwait(raceBarrier);  | 
 | 411 | + | 
 | 412 | +        // We expect either a successful response or a cancellation exception. All node responses should be released in both cases.  | 
 | 413 | +        try {  | 
 | 414 | +            final var testNodesResponse = future.actionGet(SAFE_AWAIT_TIMEOUT);  | 
 | 415 | +            assertThat(testNodesResponse.getNodes(), hasSize(capturedRequests.length));  | 
 | 416 | +            assertResponseReleased(testNodesResponse);  | 
 | 417 | +        } catch (Exception e) {  | 
 | 418 | +            final var taskCancelledException = (TaskCancelledException) ExceptionsHelper.unwrap(e, TaskCancelledException.class);  | 
 | 419 | +            assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException);  | 
 | 420 | +            assertThat(e.getMessage(), containsString("task cancelled [simulated]"));  | 
 | 421 | +            assertTrue(cancellableTask.isCancelled());  | 
 | 422 | +            safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it  | 
 | 423 | +            assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false));  | 
 | 424 | +        }  | 
 | 425 | + | 
 | 426 | +        completeThread.join(10_000);  | 
 | 427 | +        cancelThread.join(10_000);  | 
 | 428 | +        assertFalse(completeThread.isAlive());  | 
 | 429 | +        assertFalse(cancelThread.isAlive());  | 
 | 430 | +    }  | 
 | 431 | + | 
 | 432 | +    private void assertResponseReleased(TestNodesResponse response) {  | 
 | 433 | +        final var allResponsesReleasedListener = new SubscribableListener<Void>();  | 
 | 434 | +        try (var listeners = new RefCountingListener(allResponsesReleasedListener)) {  | 
 | 435 | +            response.addCloseListener(listeners.acquire());  | 
 | 436 | +            for (final var nodeResponse : response.getNodes()) {  | 
 | 437 | +                nodeResponse.addCloseListener(listeners.acquire());  | 
 | 438 | +            }  | 
 | 439 | +        }  | 
 | 440 | +        safeAwait(allResponsesReleasedListener);  | 
 | 441 | +        assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences));  | 
 | 442 | +        assertFalse(response.hasReferences());  | 
 | 443 | +    }  | 
 | 444 | + | 
 | 445 | +    private TestNodeResponse completeOneRequest(CapturingTransport.CapturedRequest capturedRequest) {  | 
 | 446 | +        final var response = new TestNodeResponse(capturedRequest.node());  | 
 | 447 | +        try {  | 
 | 448 | +            transport.getTransportResponseHandler(capturedRequest.requestId()).handleResponse(response);  | 
 | 449 | +        } finally {  | 
 | 450 | +            response.decRef();  | 
 | 451 | +        }  | 
 | 452 | +        return response;  | 
 | 453 | +    }  | 
 | 454 | + | 
319 | 455 |     @BeforeClass  | 
320 | 456 |     public static void startThreadPool() {  | 
321 | 457 |         THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());  | 
 | 
0 commit comments