|  | 
| 10 | 10 | package org.elasticsearch.action.support.nodes; | 
| 11 | 11 | 
 | 
| 12 | 12 | import org.elasticsearch.ElasticsearchException; | 
|  | 13 | +import org.elasticsearch.ExceptionsHelper; | 
| 13 | 14 | import org.elasticsearch.action.ActionListener; | 
| 14 | 15 | import org.elasticsearch.action.FailedNodeException; | 
| 15 | 16 | import org.elasticsearch.action.support.ActionFilters; | 
|  | 
| 57 | 58 | import java.util.List; | 
| 58 | 59 | import java.util.Map; | 
| 59 | 60 | import java.util.Set; | 
|  | 61 | +import java.util.concurrent.CountDownLatch; | 
|  | 62 | +import java.util.concurrent.CyclicBarrier; | 
| 60 | 63 | import java.util.concurrent.Executor; | 
| 61 | 64 | import java.util.concurrent.TimeUnit; | 
| 62 | 65 | import java.util.concurrent.atomic.AtomicInteger; | 
|  | 
| 66 | 69 | import static java.util.Collections.emptyMap; | 
| 67 | 70 | import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; | 
| 68 | 71 | import static org.elasticsearch.test.ClusterServiceUtils.setState; | 
|  | 72 | +import static org.hamcrest.Matchers.containsString; | 
| 69 | 73 | import static org.hamcrest.Matchers.greaterThan; | 
|  | 74 | +import static org.hamcrest.Matchers.hasSize; | 
| 70 | 75 | import static org.mockito.Mockito.mock; | 
| 71 | 76 | 
 | 
| 72 | 77 | public class TransportNodesActionTests extends ESTestCase { | 
| @@ -316,6 +321,137 @@ protected Object createActionContext(Task task, TestNodesRequest request) { | 
| 316 | 321 |         assertTrue(cancellableTask.isCancelled()); // keep task alive | 
| 317 | 322 |     } | 
| 318 | 323 | 
 | 
|  | 324 | +    public void testCompletionShouldNotBeInterferedByCancellationAfterProcessingBegins() throws Exception { | 
|  | 325 | +        final var barrier = new CyclicBarrier(2); | 
|  | 326 | +        final var action = new TestTransportNodesAction( | 
|  | 327 | +            clusterService, | 
|  | 328 | +            transportService, | 
|  | 329 | +            new ActionFilters(Set.of()), | 
|  | 330 | +            TestNodeRequest::new, | 
|  | 331 | +            THREAD_POOL.executor(ThreadPool.Names.GENERIC) | 
|  | 332 | +        ) { | 
|  | 333 | +            @Override | 
|  | 334 | +            protected void newResponseAsync( | 
|  | 335 | +                Task task, | 
|  | 336 | +                TestNodesRequest request, | 
|  | 337 | +                Void unused, | 
|  | 338 | +                List<TestNodeResponse> testNodeResponses, | 
|  | 339 | +                List<FailedNodeException> failures, | 
|  | 340 | +                ActionListener<TestNodesResponse> listener | 
|  | 341 | +            ) { | 
|  | 342 | +                boolean waited = false; | 
|  | 343 | +                // Process node responses in a loop and ensure no ConcurrentModificationException will be thrown due to | 
|  | 344 | +                // concurrent cancellation coming after the loop has started, see also #128852 | 
|  | 345 | +                for (var response : testNodeResponses) { | 
|  | 346 | +                    if (waited == false) { | 
|  | 347 | +                        waited = true; | 
|  | 348 | +                        safeAwait(barrier); | 
|  | 349 | +                        safeAwait(barrier); | 
|  | 350 | +                    } | 
|  | 351 | +                } | 
|  | 352 | +                super.newResponseAsync(task, request, unused, testNodeResponses, failures, listener); | 
|  | 353 | +            } | 
|  | 354 | +        }; | 
|  | 355 | + | 
|  | 356 | +        final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()); | 
|  | 357 | +        final var cancelledFuture = new PlainActionFuture<Void>(); | 
|  | 358 | +        cancellableTask.addListener(() -> cancelledFuture.onResponse(null)); | 
|  | 359 | + | 
|  | 360 | +        final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>(); | 
|  | 361 | +        action.execute(cancellableTask, new TestNodesRequest(), future); | 
|  | 362 | + | 
|  | 363 | +        for (var capturedRequest : transport.getCapturedRequestsAndClear()) { | 
|  | 364 | +            completeOneRequest(capturedRequest); | 
|  | 365 | +        } | 
|  | 366 | + | 
|  | 367 | +        // Wait for the overall response to start processing the node responses in a loop and then cancel the task. | 
|  | 368 | +        // The cancellation should not interfere with the node response processing. | 
|  | 369 | +        safeAwait(barrier); | 
|  | 370 | +        TaskCancelHelper.cancel(cancellableTask, "simulated"); | 
|  | 371 | +        safeGet(cancelledFuture); | 
|  | 372 | + | 
|  | 373 | +        // Let the process continue, and it should be successful | 
|  | 374 | +        safeAwait(barrier); | 
|  | 375 | +        assertResponseReleased(safeGet(future)); | 
|  | 376 | +    } | 
|  | 377 | + | 
|  | 378 | +    public void testConcurrentlyCompletionAndCancellation() throws InterruptedException { | 
|  | 379 | +        final var action = getTestTransportNodesAction(); | 
|  | 380 | + | 
|  | 381 | +        final CountDownLatch onCancelledLatch = new CountDownLatch(1); | 
|  | 382 | +        final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) { | 
|  | 383 | +            @Override | 
|  | 384 | +            protected void onCancelled() { | 
|  | 385 | +                onCancelledLatch.countDown(); | 
|  | 386 | +            } | 
|  | 387 | +        }; | 
|  | 388 | + | 
|  | 389 | +        final PlainActionFuture<TestNodesResponse> future = new PlainActionFuture<>(); | 
|  | 390 | +        action.execute(cancellableTask, new TestNodesRequest(), future); | 
|  | 391 | + | 
|  | 392 | +        final List<TestNodeResponse> nodeResponses = new ArrayList<>(); | 
|  | 393 | +        final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear(); | 
|  | 394 | +        for (int i = 0; i < capturedRequests.length - 1; i++) { | 
|  | 395 | +            final var capturedRequest = capturedRequests[i]; | 
|  | 396 | +            nodeResponses.add(completeOneRequest(capturedRequest)); | 
|  | 397 | +        } | 
|  | 398 | + | 
|  | 399 | +        final var raceBarrier = new CyclicBarrier(3); | 
|  | 400 | +        final Thread completeThread = new Thread(() -> { | 
|  | 401 | +            safeAwait(raceBarrier); | 
|  | 402 | +            nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1])); | 
|  | 403 | +        }); | 
|  | 404 | +        final Thread cancelThread = new Thread(() -> { | 
|  | 405 | +            safeAwait(raceBarrier); | 
|  | 406 | +            TaskCancelHelper.cancel(cancellableTask, "simulated"); | 
|  | 407 | +        }); | 
|  | 408 | +        completeThread.start(); | 
|  | 409 | +        cancelThread.start(); | 
|  | 410 | +        safeAwait(raceBarrier); | 
|  | 411 | + | 
|  | 412 | +        // We expect either a successful response or a cancellation exception. All node responses should be released in both cases. | 
|  | 413 | +        try { | 
|  | 414 | +            final var testNodesResponse = future.actionGet(SAFE_AWAIT_TIMEOUT); | 
|  | 415 | +            assertThat(testNodesResponse.getNodes(), hasSize(capturedRequests.length)); | 
|  | 416 | +            assertResponseReleased(testNodesResponse); | 
|  | 417 | +        } catch (Exception e) { | 
|  | 418 | +            final var taskCancelledException = (TaskCancelledException) ExceptionsHelper.unwrap(e, TaskCancelledException.class); | 
|  | 419 | +            assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException); | 
|  | 420 | +            assertThat(e.getMessage(), containsString("task cancelled [simulated]")); | 
|  | 421 | +            assertTrue(cancellableTask.isCancelled()); | 
|  | 422 | +            safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it | 
|  | 423 | +            assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false)); | 
|  | 424 | +        } | 
|  | 425 | + | 
|  | 426 | +        completeThread.join(10_000); | 
|  | 427 | +        cancelThread.join(10_000); | 
|  | 428 | +        assertFalse(completeThread.isAlive()); | 
|  | 429 | +        assertFalse(cancelThread.isAlive()); | 
|  | 430 | +    } | 
|  | 431 | + | 
|  | 432 | +    private void assertResponseReleased(TestNodesResponse response) { | 
|  | 433 | +        final var allResponsesReleasedListener = new SubscribableListener<Void>(); | 
|  | 434 | +        try (var listeners = new RefCountingListener(allResponsesReleasedListener)) { | 
|  | 435 | +            response.addCloseListener(listeners.acquire()); | 
|  | 436 | +            for (final var nodeResponse : response.getNodes()) { | 
|  | 437 | +                nodeResponse.addCloseListener(listeners.acquire()); | 
|  | 438 | +            } | 
|  | 439 | +        } | 
|  | 440 | +        safeAwait(allResponsesReleasedListener); | 
|  | 441 | +        assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences)); | 
|  | 442 | +        assertFalse(response.hasReferences()); | 
|  | 443 | +    } | 
|  | 444 | + | 
|  | 445 | +    private TestNodeResponse completeOneRequest(CapturingTransport.CapturedRequest capturedRequest) { | 
|  | 446 | +        final var response = new TestNodeResponse(capturedRequest.node()); | 
|  | 447 | +        try { | 
|  | 448 | +            transport.getTransportResponseHandler(capturedRequest.requestId()).handleResponse(response); | 
|  | 449 | +        } finally { | 
|  | 450 | +            response.decRef(); | 
|  | 451 | +        } | 
|  | 452 | +        return response; | 
|  | 453 | +    } | 
|  | 454 | + | 
| 319 | 455 |     @BeforeClass | 
| 320 | 456 |     public static void startThreadPool() { | 
| 321 | 457 |         THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName()); | 
|  | 
0 commit comments