Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/96279.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96279
summary: Improve cancellability in `TransportTasksAction`
area: Task Management
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,36 @@

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.NoSuchNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.CancellableFanOut;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;

import static java.util.Collections.emptyList;

/**
* The base class for transport actions that are interacting with currently running tasks.
Expand Down Expand Up @@ -85,67 +81,113 @@ protected TransportTasksAction(

@Override
protected void doExecute(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
new AsyncAction(task, request, listener).start();
}
final var discoveryNodes = clusterService.state().nodes();
final String[] nodeIds = resolveNodes(request, discoveryNodes);

new CancellableFanOut<String, NodeTasksResponse, TasksResponse>() {
final ArrayList<TaskResponse> taskResponses = new ArrayList<>();
final ArrayList<TaskOperationFailure> taskOperationFailures = new ArrayList<>();
final ArrayList<FailedNodeException> failedNodeExceptions = new ArrayList<>();
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());

@Override
protected void sendItemRequest(String nodeId, ActionListener<NodeTasksResponse> listener) {
final var discoveryNode = discoveryNodes.get(nodeId);
if (discoveryNode == null) {
listener.onFailure(new NoSuchNodeException(nodeId));
return;
}

transportService.sendChildRequest(
discoveryNode,
transportNodeAction,
new NodeTaskRequest(request),
task,
transportRequestOptions,
new ActionListenerResponseHandler<>(listener, nodeResponseReader)
);
}

@Override
protected void onItemResponse(String nodeId, NodeTasksResponse nodeTasksResponse) {
addAllSynchronized(taskResponses, nodeTasksResponse.results);
addAllSynchronized(taskOperationFailures, nodeTasksResponse.exceptions);
}

@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
private static <T> void addAllSynchronized(List<T> allResults, Collection<T> response) {
if (response.isEmpty() == false) {
synchronized (allResults) {
allResults.addAll(response);
}
}
}

@Override
protected void onItemFailure(String nodeId, Exception e) {
logger.debug(() -> Strings.format("failed to execute on node [{}]", nodeId), e);
synchronized (failedNodeExceptions) {
failedNodeExceptions.add(new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", e));
}
}

@Override
protected TasksResponse onCompletion() {
// ref releases all happen-before here so no need to be synchronized
return newResponse(request, taskResponses, taskOperationFailures, failedNodeExceptions);
}

private void nodeOperation(CancellableTask task, NodeTaskRequest nodeTaskRequest, ActionListener<NodeTasksResponse> listener) {
TasksRequest request = nodeTaskRequest.tasksRequest;
processTasks(request, ActionListener.wrap(tasks -> nodeOperation(task, listener, request, tasks), listener::onFailure));
@Override
public String toString() {
return actionName;
}
}.run(task, Iterators.forArray(nodeIds), listener);
}

// Reader used to deserialize each node's NodeTasksResponse. Kept as a field of the action rather than written as an inline
// method reference at the place it is used, to avoid capturing CancellableFanOut.this.
private final Writeable.Reader<NodeTasksResponse> nodeResponseReader = NodeTasksResponse::new;

private void nodeOperation(
CancellableTask task,
CancellableTask nodeTask,
ActionListener<NodeTasksResponse> listener,
TasksRequest request,
List<OperationTask> tasks
List<OperationTask> operationTasks
) {
if (tasks.isEmpty()) {
listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), emptyList(), emptyList()));
return;
}
AtomicArray<Tuple<TaskResponse, Exception>> responses = new AtomicArray<>(tasks.size());
final AtomicInteger counter = new AtomicInteger(tasks.size());
for (int i = 0; i < tasks.size(); i++) {
final int taskIndex = i;
ActionListener<TaskResponse> taskListener = new ActionListener<TaskResponse>() {
@Override
public void onResponse(TaskResponse response) {
responses.setOnce(taskIndex, response == null ? null : new Tuple<>(response, null));
respondIfFinished();
}
new CancellableFanOut<OperationTask, TaskResponse, NodeTasksResponse>() {

@Override
public void onFailure(Exception e) {
responses.setOnce(taskIndex, new Tuple<>(null, e));
respondIfFinished();
final ArrayList<TaskResponse> results = new ArrayList<>(operationTasks.size());
final ArrayList<TaskOperationFailure> exceptions = new ArrayList<>();

@Override
protected void sendItemRequest(OperationTask operationTask, ActionListener<TaskResponse> listener) {
ActionListener.run(listener, l -> taskOperation(nodeTask, request, operationTask, l));
}

@Override
protected void onItemResponse(OperationTask operationTask, TaskResponse taskResponse) {
synchronized (results) {
results.add(taskResponse);
}
}

private void respondIfFinished() {
if (counter.decrementAndGet() != 0) {
return;
}
List<TaskResponse> results = new ArrayList<>();
List<TaskOperationFailure> exceptions = new ArrayList<>();
for (Tuple<TaskResponse, Exception> response : responses.asList()) {
if (response.v1() == null) {
assert response.v2() != null;
exceptions.add(
new TaskOperationFailure(clusterService.localNode().getId(), tasks.get(taskIndex).getId(), response.v2())
);
} else {
assert response.v2() == null;
results.add(response.v1());
}
}
listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions));
@Override
protected void onItemFailure(OperationTask operationTask, Exception e) {
synchronized (exceptions) {
exceptions.add(new TaskOperationFailure(clusterService.localNode().getId(), operationTask.getId(), e));
}
};
try {
taskOperation(task, request, tasks.get(taskIndex), taskListener);
} catch (Exception e) {
taskListener.onFailure(e);
}
}

@Override
protected NodeTasksResponse onCompletion() {
// ref releases all happen-before here so no need to be synchronized
return new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions);
}

@Override
public String toString() {
return transportNodeAction;
}
}.run(nodeTask, operationTasks.iterator(), listener);
}

protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNodes) {
Expand Down Expand Up @@ -192,28 +234,6 @@ protected abstract TasksResponse newResponse(
List<FailedNodeException> failedNodeExceptions
);

/**
 * Splits the per-node outcomes stored in {@code responses} into task results, task-level failures and node-level failures,
 * then delegates to the list-based {@code newResponse} overload.
 *
 * @param request   the request the outcomes belong to
 * @param responses one entry per node: either a {@link FailedNodeException} or a {@link NodeTasksResponse}
 * @return the response built by the list-based overload
 */
@SuppressWarnings("unchecked")
protected TasksResponse newResponse(TasksRequest request, AtomicReferenceArray<?> responses) {
    final List<TaskResponse> collectedTasks = new ArrayList<>();
    final List<TaskOperationFailure> collectedTaskFailures = new ArrayList<>();
    final List<FailedNodeException> collectedNodeFailures = new ArrayList<>();

    for (int slot = 0; slot < responses.length(); slot++) {
        final Object outcome = responses.get(slot);
        if (outcome instanceof FailedNodeException nodeFailure) {
            collectedNodeFailures.add(nodeFailure);
            continue;
        }
        // Anything that is not a node failure must be a node response.
        final NodeTasksResponse nodeResponse = (NodeTasksResponse) outcome;
        if (nodeResponse.results != null) {
            collectedTasks.addAll(nodeResponse.results);
        }
        if (nodeResponse.exceptions != null) {
            collectedTaskFailures.addAll(nodeResponse.exceptions);
        }
    }

    return newResponse(request, collectedTasks, collectedTaskFailures, collectedNodeFailures);
}

/**
* Perform the required operation on the task. It is OK to start an asynchronous operation or to throw an exception but not both.
* @param actionTask The related transport action task. Can be used to create a task ID to handle upstream transport cancellations.
Expand All @@ -228,120 +248,18 @@ protected abstract void taskOperation(
ActionListener<TaskResponse> listener
);

/**
 * Drives one fan-out of a tasks request across nodes: resolves the target node ids at construction, sends a
 * {@link NodeTaskRequest} to each of them in {@link #start()}, records every node's outcome by index, and completes the
 * listener once the number of recorded outcomes equals the number of targeted nodes.
 */
private class AsyncAction {

private final TasksRequest request;
// Ids returned by resolveNodes; nodes[i] is the DiscoveryNode looked up for nodesIds[i], or null if the lookup found none.
private final String[] nodesIds;
private final DiscoveryNode[] nodes;
private final ActionListener<TasksResponse> listener;
// One slot per targeted node: holds either that node's NodeTasksResponse or a FailedNodeException.
private final AtomicReferenceArray<Object> responses;
// Number of nodes that have reported an outcome (response or failure) so far.
private final AtomicInteger counter = new AtomicInteger();
private final Task task;

/**
 * Captures the request context and resolves the target nodes against the current cluster state.
 *
 * @param task     the task running this action; its id is set as the parent task id on every node request
 * @param request  the tasks request to fan out
 * @param listener receives the combined response, or a failure
 */
private AsyncAction(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
this.task = task;
this.request = request;
this.listener = listener;
final DiscoveryNodes discoveryNodes = clusterService.state().nodes();
this.nodesIds = resolveNodes(request, discoveryNodes);
Map<String, DiscoveryNode> nodes = discoveryNodes.getNodes();
this.nodes = new DiscoveryNode[nodesIds.length];
for (int i = 0; i < this.nodesIds.length; i++) {
this.nodes[i] = nodes.get(this.nodesIds[i]);
}
this.responses = new AtomicReferenceArray<>(this.nodesIds.length);
}

/**
 * Sends the request to every resolved node, or responds immediately when no node was resolved.
 */
private void start() {
if (nodesIds.length == 0) {
// nothing to do
try {
listener.onResponse(newResponse(request, responses));
} catch (Exception e) {
logger.debug("failed to generate empty response", e);
listener.onFailure(e);
}
} else {
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());
for (int i = 0; i < nodesIds.length; i++) {
final String nodeId = nodesIds[i];
final int idx = i;
final DiscoveryNode node = nodes[i];
try {
if (node == null) {
// The resolved id has no matching DiscoveryNode: record a failure for this slot instead of sending.
onFailure(idx, nodeId, new NoSuchNodeException(nodeId));
} else {
NodeTaskRequest nodeRequest = new NodeTaskRequest(request);
// Set this action's task (local node id + task id) as the parent of the node request.
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
transportService.sendRequest(
node,
transportNodeAction,
nodeRequest,
transportRequestOptions,
new TransportResponseHandler<NodeTasksResponse>() {
@Override
public NodeTasksResponse read(StreamInput in) throws IOException {
return new NodeTasksResponse(in);
}

@Override
public void handleResponse(NodeTasksResponse response) {
onOperation(idx, response);
}

@Override
public void handleException(TransportException exp) {
onFailure(idx, node.getId(), exp);
}
}
);
}
} catch (Exception e) {
// Any exception thrown inside the try block above is recorded as this node's failure.
onFailure(idx, nodeId, e);
}
}
}
}

/**
 * Stores a node's response in its slot and finishes if every slot has now reported.
 */
private void onOperation(int idx, NodeTasksResponse nodeResponse) {
responses.set(idx, nodeResponse);
if (counter.incrementAndGet() == responses.length()) {
finishHim();
}
}

/**
 * Logs the failure, stores it in the node's slot wrapped in a {@link FailedNodeException}, and finishes if every slot has
 * now reported.
 */
private void onFailure(int idx, String nodeId, Throwable t) {
logger.debug(() -> "failed to execute on node [" + nodeId + "]", t);

responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t));

if (counter.incrementAndGet() == responses.length()) {
finishHim();
}
}

/**
 * Completes the listener after all nodes have reported: with the combined response, or with the exception thrown while
 * combining.
 */
private void finishHim() {
// NOTE(review): presumably notifyIfCancelled completes the listener itself when it returns true - confirm against
// CancellableTask; nothing further is sent to the listener on that path here.
if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
return;
}
TasksResponse finalResponse;
try {
finalResponse = newResponse(request, responses);
} catch (Exception e) {
logger.debug("failed to combine responses from nodes", e);
listener.onFailure(e);
return;
}
listener.onResponse(finalResponse);
}
}

/**
 * Transport handler for the node-level action: resolves the tasks targeted by the incoming request via {@code processTasks}
 * and runs the operation on them, answering on the transport channel.
 */
class NodeTransportHandler implements TransportRequestHandler<NodeTaskRequest> {

    /**
     * Handles one node-level request. The request is dispatched exactly once, so the channel receives a single answer.
     *
     * @param request the node-level request wrapping the original tasks request
     * @param channel the channel on which the node response (or failure) is sent
     * @param task    the task registered for this request; asserted, then cast, to {@link CancellableTask}
     */
    @Override
    public void messageReceived(final NodeTaskRequest request, final TransportChannel channel, Task task) throws Exception {
        assert task instanceof CancellableTask;
        TasksRequest tasksRequest = request.tasksRequest;
        processTasks(
            tasksRequest,
            // NOTE(review): delegateFailure presumably routes a processTasks failure straight to the channel listener - confirm.
            new ChannelActionListener<NodeTasksResponse>(channel).delegateFailure(
                (l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks)
            )
        );
    }
}

Expand Down
Loading