diff --git a/docs/changelog/96279.yaml b/docs/changelog/96279.yaml new file mode 100644 index 0000000000000..39c14d64e34a4 --- /dev/null +++ b/docs/changelog/96279.yaml @@ -0,0 +1,5 @@ +pr: 96279 +summary: Improve cancellability in `TransportTasksAction` +area: Task Management +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java b/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java index daedacf6fb4ad..4c563b95449e7 100644 --- a/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java @@ -10,40 +10,36 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.NoSuchNodeException; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.CancellableFanOut; import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.AtomicArray; -import org.elasticsearch.core.Tuple; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.transport.TransportChannel; -import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; -import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReferenceArray; - -import static java.util.Collections.emptyList; /** * The base class for transport actions that are interacting with currently running tasks. @@ -85,67 +81,113 @@ protected TransportTasksAction( @Override protected void doExecute(Task task, TasksRequest request, ActionListener listener) { - new AsyncAction(task, request, listener).start(); - } + final var discoveryNodes = clusterService.state().nodes(); + final String[] nodeIds = resolveNodes(request, discoveryNodes); + + new CancellableFanOut() { + final ArrayList taskResponses = new ArrayList<>(); + final ArrayList taskOperationFailures = new ArrayList<>(); + final ArrayList failedNodeExceptions = new ArrayList<>(); + final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout()); + + @Override + protected void sendItemRequest(String nodeId, ActionListener listener) { + final var discoveryNode = discoveryNodes.get(nodeId); + if (discoveryNode == null) { + listener.onFailure(new NoSuchNodeException(nodeId)); + return; + } + + transportService.sendChildRequest( + discoveryNode, + transportNodeAction, + new NodeTaskRequest(request), + task, + transportRequestOptions, + new ActionListenerResponseHandler<>(listener, nodeResponseReader) + ); + } + + @Override + protected void onItemResponse(String nodeId, NodeTasksResponse nodeTasksResponse) { + addAllSynchronized(taskResponses, nodeTasksResponse.results); + addAllSynchronized(taskOperationFailures, nodeTasksResponse.exceptions); + } + + @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter") + private static void addAllSynchronized(List allResults, Collection response) { + if (response.isEmpty() == false) { + synchronized (allResults) { + allResults.addAll(response); + } + } + } + + @Override + protected void onItemFailure(String nodeId, Exception e) { + logger.debug(() -> Strings.format("failed to execute on node [{}]", nodeId), e); + synchronized (failedNodeExceptions) { + failedNodeExceptions.add(new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", e)); + } + } + + @Override + protected TasksResponse onCompletion() { + // ref releases all happen-before here so no need to be synchronized + return newResponse(request, taskResponses, taskOperationFailures, failedNodeExceptions); + } - private void nodeOperation(CancellableTask task, NodeTaskRequest nodeTaskRequest, ActionListener listener) { - TasksRequest request = nodeTaskRequest.tasksRequest; - processTasks(request, ActionListener.wrap(tasks -> nodeOperation(task, listener, request, tasks), listener::onFailure)); + @Override + public String toString() { + return actionName; + } + }.run(task, Iterators.forArray(nodeIds), listener); } + // not an inline method reference to avoid capturing CancellableFanOut.this. + private final Writeable.Reader nodeResponseReader = NodeTasksResponse::new; + private void nodeOperation( - CancellableTask task, + CancellableTask nodeTask, ActionListener listener, TasksRequest request, - List tasks + List operationTasks ) { - if (tasks.isEmpty()) { - listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), emptyList(), emptyList())); - return; - } - AtomicArray> responses = new AtomicArray<>(tasks.size()); - final AtomicInteger counter = new AtomicInteger(tasks.size()); - for (int i = 0; i < tasks.size(); i++) { - final int taskIndex = i; - ActionListener taskListener = new ActionListener() { - @Override - public void onResponse(TaskResponse response) { - responses.setOnce(taskIndex, response == null ? null : new Tuple<>(response, null)); - respondIfFinished(); - } + new CancellableFanOut() { - @Override - public void onFailure(Exception e) { - responses.setOnce(taskIndex, new Tuple<>(null, e)); - respondIfFinished(); + final ArrayList results = new ArrayList<>(operationTasks.size()); + final ArrayList exceptions = new ArrayList<>(); + + @Override + protected void sendItemRequest(OperationTask operationTask, ActionListener listener) { + ActionListener.run(listener, l -> taskOperation(nodeTask, request, operationTask, l)); + } + + @Override + protected void onItemResponse(OperationTask operationTask, TaskResponse taskResponse) { + synchronized (results) { + results.add(taskResponse); } + } - private void respondIfFinished() { - if (counter.decrementAndGet() != 0) { - return; - } - List results = new ArrayList<>(); - List exceptions = new ArrayList<>(); - for (Tuple response : responses.asList()) { - if (response.v1() == null) { - assert response.v2() != null; - exceptions.add( - new TaskOperationFailure(clusterService.localNode().getId(), tasks.get(taskIndex).getId(), response.v2()) - ); - } else { - assert response.v2() == null; - results.add(response.v1()); - } - } - listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions)); + @Override + protected void onItemFailure(OperationTask operationTask, Exception e) { + synchronized (exceptions) { + exceptions.add(new TaskOperationFailure(clusterService.localNode().getId(), operationTask.getId(), e)); } - }; - try { - taskOperation(task, request, tasks.get(taskIndex), taskListener); - } catch (Exception e) { - taskListener.onFailure(e); } - } + + @Override + protected NodeTasksResponse onCompletion() { + // ref releases all happen-before here so no need to be synchronized + return new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions); + } + + @Override + public String toString() { + return transportNodeAction; + } + }.run(nodeTask, operationTasks.iterator(), listener); } protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNodes) { @@ -192,28 +234,6 @@ protected abstract TasksResponse newResponse( List failedNodeExceptions ); - @SuppressWarnings("unchecked") - protected TasksResponse newResponse(TasksRequest request, AtomicReferenceArray responses) { - List tasks = new ArrayList<>(); - List failedNodeExceptions = new ArrayList<>(); - List taskOperationFailures = new ArrayList<>(); - for (int i = 0; i < responses.length(); i++) { - Object response = responses.get(i); - if (response instanceof FailedNodeException) { - failedNodeExceptions.add((FailedNodeException) response); - } else { - NodeTasksResponse tasksResponse = (NodeTasksResponse) response; - if (tasksResponse.results != null) { - tasks.addAll(tasksResponse.results); - } - if (tasksResponse.exceptions != null) { - taskOperationFailures.addAll(tasksResponse.exceptions); - } - } - } - return newResponse(request, tasks, taskOperationFailures, failedNodeExceptions); - } - /** * Perform the required operation on the task. It is OK start an asynchronous operation or to throw an exception but not both. * @param actionTask The related transport action task. Can be used to create a task ID to handle upstream transport cancellations. @@ -228,120 +248,18 @@ protected abstract void taskOperation( ActionListener listener ); - private class AsyncAction { - - private final TasksRequest request; - private final String[] nodesIds; - private final DiscoveryNode[] nodes; - private final ActionListener listener; - private final AtomicReferenceArray responses; - private final AtomicInteger counter = new AtomicInteger(); - private final Task task; - - private AsyncAction(Task task, TasksRequest request, ActionListener listener) { - this.task = task; - this.request = request; - this.listener = listener; - final DiscoveryNodes discoveryNodes = clusterService.state().nodes(); - this.nodesIds = resolveNodes(request, discoveryNodes); - Map nodes = discoveryNodes.getNodes(); - this.nodes = new DiscoveryNode[nodesIds.length]; - for (int i = 0; i < this.nodesIds.length; i++) { - this.nodes[i] = nodes.get(this.nodesIds[i]); - } - this.responses = new AtomicReferenceArray<>(this.nodesIds.length); - } - - private void start() { - if (nodesIds.length == 0) { - // nothing to do - try { - listener.onResponse(newResponse(request, responses)); - } catch (Exception e) { - logger.debug("failed to generate empty response", e); - listener.onFailure(e); - } - } else { - final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout()); - for (int i = 0; i < nodesIds.length; i++) { - final String nodeId = nodesIds[i]; - final int idx = i; - final DiscoveryNode node = nodes[i]; - try { - if (node == null) { - onFailure(idx, nodeId, new NoSuchNodeException(nodeId)); - } else { - NodeTaskRequest nodeRequest = new NodeTaskRequest(request); - nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - transportService.sendRequest( - node, - transportNodeAction, - nodeRequest, - transportRequestOptions, - new TransportResponseHandler() { - @Override - public NodeTasksResponse read(StreamInput in) throws IOException { - return new NodeTasksResponse(in); - } - - @Override - public void handleResponse(NodeTasksResponse response) { - onOperation(idx, response); - } - - @Override - public void handleException(TransportException exp) { - onFailure(idx, node.getId(), exp); - } - } - ); - } - } catch (Exception e) { - onFailure(idx, nodeId, e); - } - } - } - } - - private void onOperation(int idx, NodeTasksResponse nodeResponse) { - responses.set(idx, nodeResponse); - if (counter.incrementAndGet() == responses.length()) { - finishHim(); - } - } - - private void onFailure(int idx, String nodeId, Throwable t) { - logger.debug(() -> "failed to execute on node [" + nodeId + "]", t); - - responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)); - - if (counter.incrementAndGet() == responses.length()) { - finishHim(); - } - } - - private void finishHim() { - if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) { - return; - } - TasksResponse finalResponse; - try { - finalResponse = newResponse(request, responses); - } catch (Exception e) { - logger.debug("failed to combine responses from nodes", e); - listener.onFailure(e); - return; - } - listener.onResponse(finalResponse); - } - } - class NodeTransportHandler implements TransportRequestHandler { @Override public void messageReceived(final NodeTaskRequest request, final TransportChannel channel, Task task) throws Exception { assert task instanceof CancellableTask; - nodeOperation((CancellableTask) task, request, new ChannelActionListener<>(channel)); + TasksRequest tasksRequest = request.tasksRequest; + processTasks( + tasksRequest, + new ChannelActionListener(channel).delegateFailure( + (l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks) + ) + ); } } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index cbd4a10bda3d2..f0f2e8c174ac6 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -7,6 +7,7 @@ */ package org.elasticsearch.action.admin.cluster.node.tasks; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; @@ -40,6 +41,7 @@ import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.test.ReachabilityChecker; import org.elasticsearch.test.tasks.MockTaskManager; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportRequest; @@ -55,9 +57,12 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.stream.Collectors; import static org.elasticsearch.action.support.PlainActionFuture.newFuture; @@ -68,6 +73,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; public class TransportTasksActionTests extends TaskManagerTestCase { @@ -674,6 +680,152 @@ protected void taskOperation( assertEquals(0, responses.failureCount()); } + public void testTaskResponsesDiscardedOnCancellation() throws Exception { + setupTestNodes(Settings.EMPTY); + connectNodes(testNodes); + CountDownLatch blockedActionLatch = new CountDownLatch(1); + ActionFuture future = startBlockingTestNodesAction(blockedActionLatch); + + final var taskResponseListeners = new LinkedBlockingQueue>(); + final var taskResponseListenersCountDown = new CountDownLatch(2); // test action plus the list[n] action + + final TestTasksAction tasksAction = new TestTasksAction( + "internal:testTasksAction", + testNodes[0].clusterService, + testNodes[0].transportService + ) { + @Override + protected void taskOperation( + CancellableTask actionTask, + TestTasksRequest request, + Task task, + ActionListener listener + ) { + taskResponseListeners.add(listener); + taskResponseListenersCountDown.countDown(); + } + }; + + TestTasksRequest testTasksRequest = new TestTasksRequest(); + testTasksRequest.setNodes(testNodes[0].getNodeId()); // only local node + PlainActionFuture taskFuture = newFuture(); + CancellableTask task = (CancellableTask) testNodes[0].transportService.getTaskManager() + .registerAndExecute( + "direct", + tasksAction, + testTasksRequest, + testNodes[0].transportService.getLocalNodeConnection(), + taskFuture + ); + safeAwait(taskResponseListenersCountDown); + + final var reachabilityChecker = new ReachabilityChecker(); + + final var listener0 = Objects.requireNonNull(taskResponseListeners.poll()); + if (randomBoolean()) { + listener0.onResponse(reachabilityChecker.register(new TestTaskResponse("status"))); + } else { + listener0.onFailure(reachabilityChecker.register(new ElasticsearchException("simulated"))); + } + reachabilityChecker.checkReachable(); + + PlainActionFuture.get( + fut -> testNodes[0].transportService.getTaskManager().cancelTaskAndDescendants(task, "test", false, fut), + 10, + TimeUnit.SECONDS + ); + + reachabilityChecker.ensureUnreachable(); + + while (true) { + final var listener = taskResponseListeners.poll(); + if (listener == null) { + break; + } + if (randomBoolean()) { + listener.onResponse(reachabilityChecker.register(new TestTaskResponse("status"))); + } else { + listener.onFailure(reachabilityChecker.register(new ElasticsearchException("simulated"))); + } + reachabilityChecker.ensureUnreachable(); + } + + expectThrows(TaskCancelledException.class, taskFuture::actionGet); + + blockedActionLatch.countDown(); + NodesResponse responses = future.get(10, TimeUnit.SECONDS); + assertEquals(0, responses.failureCount()); + } + + public void testNodeResponsesDiscardedOnCancellation() { + setupTestNodes(Settings.EMPTY); + connectNodes(testNodes); + + final var taskResponseListeners = new AtomicReferenceArray>(testNodes.length); + final var taskResponseListenersCountDown = new CountDownLatch(testNodes.length); // one list[n] action per node + final var tasksActions = new TestTasksAction[testNodes.length]; + for (int i = 0; i < testNodes.length; i++) { + final var nodeIndex = i; + tasksActions[i] = new TestTasksAction("internal:testTasksAction", testNodes[i].clusterService, testNodes[i].transportService) { + @Override + protected void taskOperation( + CancellableTask actionTask, + TestTasksRequest request, + Task task, + ActionListener listener + ) { + assertThat(taskResponseListeners.getAndSet(nodeIndex, ActionListener.notifyOnce(listener)), nullValue()); + taskResponseListenersCountDown.countDown(); + } + }; + } + + TestTasksRequest testTasksRequest = new TestTasksRequest(); + testTasksRequest.setActions("internal:testTasksAction[n]"); + PlainActionFuture taskFuture = newFuture(); + CancellableTask task = (CancellableTask) testNodes[0].transportService.getTaskManager() + .registerAndExecute( + "direct", + tasksActions[0], + testTasksRequest, + testNodes[0].transportService.getLocalNodeConnection(), + taskFuture + ); + safeAwait(taskResponseListenersCountDown); + + final var reachabilityChecker = new ReachabilityChecker(); + + if (randomBoolean()) { + // local node does not de/serialize node-level response so retains references to the task-level response + if (randomBoolean()) { + taskResponseListeners.get(0).onResponse(reachabilityChecker.register(new TestTaskResponse("status"))); + } else { + taskResponseListeners.get(0).onFailure(reachabilityChecker.register(new ElasticsearchException("simulated"))); + } + reachabilityChecker.checkReachable(); + } + + PlainActionFuture.get( + fut -> testNodes[0].transportService.getTaskManager().cancelTaskAndDescendants(task, "test", false, fut), + 10, + TimeUnit.SECONDS + ); + + reachabilityChecker.ensureUnreachable(); + assertFalse(taskFuture.isDone()); + + for (int i = 0; i < testNodes.length; i++) { + if (randomBoolean()) { + taskResponseListeners.get(i).onResponse(reachabilityChecker.register(new TestTaskResponse("status"))); + } else { + taskResponseListeners.get(i).onFailure(reachabilityChecker.register(new ElasticsearchException("simulated"))); + } + reachabilityChecker.ensureUnreachable(); + } + + expectThrows(TaskCancelledException.class, taskFuture::actionGet); + } + public void testTaskLevelActionFailures() throws Exception { setupTestNodes(Settings.EMPTY); connectNodes(testNodes);