diff --git a/docs/changelog/130303.yaml b/docs/changelog/130303.yaml new file mode 100644 index 0000000000000..aff277f67eba1 --- /dev/null +++ b/docs/changelog/130303.yaml @@ -0,0 +1,5 @@ +pr: 130303 +summary: Drain responses on completion for `TransportNodesAction` +area: Distributed +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java index 1457aa4dce82d..2eb5820a20ecb 100644 --- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java @@ -46,6 +46,7 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.core.Strings.format; @@ -103,6 +104,7 @@ protected void doExecute(Task task, NodesRequest request, ActionListener responses = new ArrayList<>(concreteNodes.length); final ArrayList exceptions = new ArrayList<>(0); + final AtomicBoolean responsesHandled = new AtomicBoolean(false); final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); @@ -113,12 +115,14 @@ protected void doExecute(Task task, NodesRequest request, ActionListener { - final List drainedResponses; - synchronized (responses) { - drainedResponses = List.copyOf(responses); - responses.clear(); + if (responsesHandled.compareAndSet(false, true)) { + final List drainedResponses; + synchronized (responses) { + drainedResponses = List.copyOf(responses); + responses.clear(); + } + Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close(); } - Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close(); }); } } @@ -165,10 +169,18 @@ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) { @Override protected CheckedConsumer, Exception> onCompletion() { - // ref releases all happen-before here so no need to be synchronized return l -> { - try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) { - newResponseAsync(task, request, actionContext, responses, exceptions, l); + if (responsesHandled.compareAndSet(false, true)) { + // ref releases all happen-before here so no need to be synchronized + try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) { + newResponseAsync(task, request, actionContext, responses, exceptions, l); + } + } else { + logger.debug("task cancelled after all responses were collected"); + assert task instanceof CancellableTask : "expect CancellableTask, but got: " + task; + final var cancellableTask = (CancellableTask) task; + assert cancellableTask.isCancelled(); + cancellableTask.notifyIfCancelled(l); } }; } diff --git a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java index 4a3b060c3e1c0..6ea4a86fae42e 100644 --- a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.action.support.nodes; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.ActionFilters; @@ -57,6 +58,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -66,7 +69,9 @@ import static java.util.Collections.emptyMap; import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.test.ClusterServiceUtils.setState; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; public class TransportNodesActionTests extends ESTestCase { @@ -316,6 +321,137 @@ protected Object createActionContext(Task task, TestNodesRequest request) { assertTrue(cancellableTask.isCancelled()); // keep task alive } + public void testCompletionShouldNotBeInterferedByCancellationAfterProcessingBegins() throws Exception { + final var barrier = new CyclicBarrier(2); + final var action = new TestTransportNodesAction( + clusterService, + transportService, + new ActionFilters(Set.of()), + TestNodeRequest::new, + THREAD_POOL.executor(ThreadPool.Names.GENERIC) + ) { + @Override + protected void newResponseAsync( + Task task, + TestNodesRequest request, + Void unused, + List testNodeResponses, + List failures, + ActionListener listener + ) { + boolean waited = false; + // Process node responses in a loop and ensure no ConcurrentModificationException will be thrown due to + // concurrent cancellation coming after the loop has started, see also #128852 + for (var response : testNodeResponses) { + if (waited == false) { + waited = true; + safeAwait(barrier); + safeAwait(barrier); + } + } + super.newResponseAsync(task, request, unused, testNodeResponses, failures, listener); + } + }; + + final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()); + final var cancelledFuture = new PlainActionFuture(); + cancellableTask.addListener(() -> cancelledFuture.onResponse(null)); + + final PlainActionFuture future = new PlainActionFuture<>(); + action.execute(cancellableTask, new TestNodesRequest(), future); + + for (var capturedRequest : transport.getCapturedRequestsAndClear()) { + completeOneRequest(capturedRequest); + } + + // Wait for the overall response to start processing the node responses in a loop and then cancel the task. + // The cancellation should not interfere with the node response processing. + safeAwait(barrier); + TaskCancelHelper.cancel(cancellableTask, "simulated"); + safeGet(cancelledFuture); + + // Let the process continue, and it should be successful + safeAwait(barrier); + assertResponseReleased(safeGet(future)); + } + + public void testConcurrentlyCompletionAndCancellation() throws InterruptedException { + final var action = getTestTransportNodesAction(); + + final CountDownLatch onCancelledLatch = new CountDownLatch(1); + final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) { + @Override + protected void onCancelled() { + onCancelledLatch.countDown(); + } + }; + + final PlainActionFuture future = new PlainActionFuture<>(); + action.execute(cancellableTask, new TestNodesRequest(), future); + + final List nodeResponses = new ArrayList<>(); + final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear(); + for (int i = 0; i < capturedRequests.length - 1; i++) { + final var capturedRequest = capturedRequests[i]; + nodeResponses.add(completeOneRequest(capturedRequest)); + } + + final var raceBarrier = new CyclicBarrier(3); + final Thread completeThread = new Thread(() -> { + safeAwait(raceBarrier); + nodeResponses.add(completeOneRequest(capturedRequests[capturedRequests.length - 1])); + }); + final Thread cancelThread = new Thread(() -> { + safeAwait(raceBarrier); + TaskCancelHelper.cancel(cancellableTask, "simulated"); + }); + completeThread.start(); + cancelThread.start(); + safeAwait(raceBarrier); + + // We expect either a successful response or a cancellation exception. All node responses should be released in both cases. + try { + final var testNodesResponse = future.actionGet(SAFE_AWAIT_TIMEOUT); + assertThat(testNodesResponse.getNodes(), hasSize(capturedRequests.length)); + assertResponseReleased(testNodesResponse); + } catch (Exception e) { + final var taskCancelledException = (TaskCancelledException) ExceptionsHelper.unwrap(e, TaskCancelledException.class); + assertNotNull("expect task cancellation exception, but got\n" + ExceptionsHelper.stackTrace(e), taskCancelledException); + assertThat(e.getMessage(), containsString("task cancelled [simulated]")); + assertTrue(cancellableTask.isCancelled()); + safeAwait(onCancelledLatch); // wait for the latch, the listener for releasing node responses is called before it + assertTrue(nodeResponses.stream().allMatch(r -> r.hasReferences() == false)); + } + + completeThread.join(10_000); + cancelThread.join(10_000); + assertFalse(completeThread.isAlive()); + assertFalse(cancelThread.isAlive()); + } + + private void assertResponseReleased(TestNodesResponse response) { + final var allResponsesReleasedListener = new SubscribableListener(); + try (var listeners = new RefCountingListener(allResponsesReleasedListener)) { + response.addCloseListener(listeners.acquire()); + for (final var nodeResponse : response.getNodes()) { + nodeResponse.addCloseListener(listeners.acquire()); + } + } + safeAwait(allResponsesReleasedListener); + assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences)); + assertFalse(response.hasReferences()); + } + + private TestNodeResponse completeOneRequest(CapturingTransport.CapturedRequest capturedRequest) { + final var response = new TestNodeResponse(capturedRequest.node()); + try { + transport.getTransportResponseHandler(capturedRequest.requestId()).handleResponse(response); + } finally { + response.decRef(); + } + return response; + } + @BeforeClass public static void startThreadPool() { THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());