Skip to content

Commit 2513104

Browse files
authored
Improve cancellability in TransportTasksAction (#96279)
Each `TransportTasksAction` fans out to multiple nodes, accumulates responses and retains them until all the nodes have responded, and then converts the responses into a final result. Similarly to #92987 and #93484, we should accumulate the responses in a structure that doesn't require so much copying later on, and should drop the received responses if the task is cancelled while some nodes' responses are still pending.
1 parent 0cf80f4 commit 2513104

File tree

3 files changed

+264
-189
lines changed

3 files changed

+264
-189
lines changed

docs/changelog/96279.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 96279
2+
summary: Improve cancellability in `TransportTasksAction`
3+
area: Task Management
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java

Lines changed: 107 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,36 @@
1010

1111
import org.elasticsearch.ResourceNotFoundException;
1212
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.action.ActionListenerResponseHandler;
1314
import org.elasticsearch.action.FailedNodeException;
1415
import org.elasticsearch.action.NoSuchNodeException;
1516
import org.elasticsearch.action.TaskOperationFailure;
1617
import org.elasticsearch.action.support.ActionFilters;
18+
import org.elasticsearch.action.support.CancellableFanOut;
1719
import org.elasticsearch.action.support.ChannelActionListener;
1820
import org.elasticsearch.action.support.HandledTransportAction;
19-
import org.elasticsearch.cluster.node.DiscoveryNode;
2021
import org.elasticsearch.cluster.node.DiscoveryNodes;
2122
import org.elasticsearch.cluster.service.ClusterService;
23+
import org.elasticsearch.common.Strings;
24+
import org.elasticsearch.common.collect.Iterators;
2225
import org.elasticsearch.common.io.stream.StreamInput;
2326
import org.elasticsearch.common.io.stream.StreamOutput;
2427
import org.elasticsearch.common.io.stream.Writeable;
25-
import org.elasticsearch.common.util.concurrent.AtomicArray;
26-
import org.elasticsearch.core.Tuple;
2728
import org.elasticsearch.tasks.CancellableTask;
2829
import org.elasticsearch.tasks.Task;
2930
import org.elasticsearch.tasks.TaskId;
3031
import org.elasticsearch.transport.TransportChannel;
31-
import org.elasticsearch.transport.TransportException;
3232
import org.elasticsearch.transport.TransportRequest;
3333
import org.elasticsearch.transport.TransportRequestHandler;
3434
import org.elasticsearch.transport.TransportRequestOptions;
3535
import org.elasticsearch.transport.TransportResponse;
36-
import org.elasticsearch.transport.TransportResponseHandler;
3736
import org.elasticsearch.transport.TransportService;
3837

3938
import java.io.IOException;
4039
import java.util.ArrayList;
40+
import java.util.Collection;
4141
import java.util.List;
4242
import java.util.Map;
43-
import java.util.concurrent.atomic.AtomicInteger;
44-
import java.util.concurrent.atomic.AtomicReferenceArray;
45-
46-
import static java.util.Collections.emptyList;
4743

4844
/**
4945
* The base class for transport actions that are interacting with currently running tasks.
@@ -85,67 +81,113 @@ protected TransportTasksAction(
8581

8682
@Override
8783
protected void doExecute(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
88-
new AsyncAction(task, request, listener).start();
89-
}
84+
final var discoveryNodes = clusterService.state().nodes();
85+
final String[] nodeIds = resolveNodes(request, discoveryNodes);
86+
87+
new CancellableFanOut<String, NodeTasksResponse, TasksResponse>() {
88+
final ArrayList<TaskResponse> taskResponses = new ArrayList<>();
89+
final ArrayList<TaskOperationFailure> taskOperationFailures = new ArrayList<>();
90+
final ArrayList<FailedNodeException> failedNodeExceptions = new ArrayList<>();
91+
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());
92+
93+
@Override
94+
protected void sendItemRequest(String nodeId, ActionListener<NodeTasksResponse> listener) {
95+
final var discoveryNode = discoveryNodes.get(nodeId);
96+
if (discoveryNode == null) {
97+
listener.onFailure(new NoSuchNodeException(nodeId));
98+
return;
99+
}
100+
101+
transportService.sendChildRequest(
102+
discoveryNode,
103+
transportNodeAction,
104+
new NodeTaskRequest(request),
105+
task,
106+
transportRequestOptions,
107+
new ActionListenerResponseHandler<>(listener, nodeResponseReader)
108+
);
109+
}
110+
111+
@Override
112+
protected void onItemResponse(String nodeId, NodeTasksResponse nodeTasksResponse) {
113+
addAllSynchronized(taskResponses, nodeTasksResponse.results);
114+
addAllSynchronized(taskOperationFailures, nodeTasksResponse.exceptions);
115+
}
116+
117+
@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
118+
private static <T> void addAllSynchronized(List<T> allResults, Collection<T> response) {
119+
if (response.isEmpty() == false) {
120+
synchronized (allResults) {
121+
allResults.addAll(response);
122+
}
123+
}
124+
}
125+
126+
@Override
127+
protected void onItemFailure(String nodeId, Exception e) {
128+
logger.debug(() -> Strings.format("failed to execute on node [{}]", nodeId), e);
129+
synchronized (failedNodeExceptions) {
130+
failedNodeExceptions.add(new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", e));
131+
}
132+
}
133+
134+
@Override
135+
protected TasksResponse onCompletion() {
136+
// ref releases all happen-before here so no need to be synchronized
137+
return newResponse(request, taskResponses, taskOperationFailures, failedNodeExceptions);
138+
}
90139

91-
private void nodeOperation(CancellableTask task, NodeTaskRequest nodeTaskRequest, ActionListener<NodeTasksResponse> listener) {
92-
TasksRequest request = nodeTaskRequest.tasksRequest;
93-
processTasks(request, ActionListener.wrap(tasks -> nodeOperation(task, listener, request, tasks), listener::onFailure));
140+
@Override
141+
public String toString() {
142+
return actionName;
143+
}
144+
}.run(task, Iterators.forArray(nodeIds), listener);
94145
}
95146

147+
// not an inline method reference to avoid capturing CancellableFanOut.this.
148+
private final Writeable.Reader<NodeTasksResponse> nodeResponseReader = NodeTasksResponse::new;
149+
96150
private void nodeOperation(
97-
CancellableTask task,
151+
CancellableTask nodeTask,
98152
ActionListener<NodeTasksResponse> listener,
99153
TasksRequest request,
100-
List<OperationTask> tasks
154+
List<OperationTask> operationTasks
101155
) {
102-
if (tasks.isEmpty()) {
103-
listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), emptyList(), emptyList()));
104-
return;
105-
}
106-
AtomicArray<Tuple<TaskResponse, Exception>> responses = new AtomicArray<>(tasks.size());
107-
final AtomicInteger counter = new AtomicInteger(tasks.size());
108-
for (int i = 0; i < tasks.size(); i++) {
109-
final int taskIndex = i;
110-
ActionListener<TaskResponse> taskListener = new ActionListener<TaskResponse>() {
111-
@Override
112-
public void onResponse(TaskResponse response) {
113-
responses.setOnce(taskIndex, response == null ? null : new Tuple<>(response, null));
114-
respondIfFinished();
115-
}
156+
new CancellableFanOut<OperationTask, TaskResponse, NodeTasksResponse>() {
116157

117-
@Override
118-
public void onFailure(Exception e) {
119-
responses.setOnce(taskIndex, new Tuple<>(null, e));
120-
respondIfFinished();
158+
final ArrayList<TaskResponse> results = new ArrayList<>(operationTasks.size());
159+
final ArrayList<TaskOperationFailure> exceptions = new ArrayList<>();
160+
161+
@Override
162+
protected void sendItemRequest(OperationTask operationTask, ActionListener<TaskResponse> listener) {
163+
ActionListener.run(listener, l -> taskOperation(nodeTask, request, operationTask, l));
164+
}
165+
166+
@Override
167+
protected void onItemResponse(OperationTask operationTask, TaskResponse taskResponse) {
168+
synchronized (results) {
169+
results.add(taskResponse);
121170
}
171+
}
122172

123-
private void respondIfFinished() {
124-
if (counter.decrementAndGet() != 0) {
125-
return;
126-
}
127-
List<TaskResponse> results = new ArrayList<>();
128-
List<TaskOperationFailure> exceptions = new ArrayList<>();
129-
for (Tuple<TaskResponse, Exception> response : responses.asList()) {
130-
if (response.v1() == null) {
131-
assert response.v2() != null;
132-
exceptions.add(
133-
new TaskOperationFailure(clusterService.localNode().getId(), tasks.get(taskIndex).getId(), response.v2())
134-
);
135-
} else {
136-
assert response.v2() == null;
137-
results.add(response.v1());
138-
}
139-
}
140-
listener.onResponse(new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions));
173+
@Override
174+
protected void onItemFailure(OperationTask operationTask, Exception e) {
175+
synchronized (exceptions) {
176+
exceptions.add(new TaskOperationFailure(clusterService.localNode().getId(), operationTask.getId(), e));
141177
}
142-
};
143-
try {
144-
taskOperation(task, request, tasks.get(taskIndex), taskListener);
145-
} catch (Exception e) {
146-
taskListener.onFailure(e);
147178
}
148-
}
179+
180+
@Override
181+
protected NodeTasksResponse onCompletion() {
182+
// ref releases all happen-before here so no need to be synchronized
183+
return new NodeTasksResponse(clusterService.localNode().getId(), results, exceptions);
184+
}
185+
186+
@Override
187+
public String toString() {
188+
return transportNodeAction;
189+
}
190+
}.run(nodeTask, operationTasks.iterator(), listener);
149191
}
150192

151193
protected String[] resolveNodes(TasksRequest request, DiscoveryNodes discoveryNodes) {
@@ -192,28 +234,6 @@ protected abstract TasksResponse newResponse(
192234
List<FailedNodeException> failedNodeExceptions
193235
);
194236

195-
@SuppressWarnings("unchecked")
196-
protected TasksResponse newResponse(TasksRequest request, AtomicReferenceArray<?> responses) {
197-
List<TaskResponse> tasks = new ArrayList<>();
198-
List<FailedNodeException> failedNodeExceptions = new ArrayList<>();
199-
List<TaskOperationFailure> taskOperationFailures = new ArrayList<>();
200-
for (int i = 0; i < responses.length(); i++) {
201-
Object response = responses.get(i);
202-
if (response instanceof FailedNodeException) {
203-
failedNodeExceptions.add((FailedNodeException) response);
204-
} else {
205-
NodeTasksResponse tasksResponse = (NodeTasksResponse) response;
206-
if (tasksResponse.results != null) {
207-
tasks.addAll(tasksResponse.results);
208-
}
209-
if (tasksResponse.exceptions != null) {
210-
taskOperationFailures.addAll(tasksResponse.exceptions);
211-
}
212-
}
213-
}
214-
return newResponse(request, tasks, taskOperationFailures, failedNodeExceptions);
215-
}
216-
217237
/**
218238
* Perform the required operation on the task. It is OK to start an asynchronous operation or to throw an exception but not both.
219239
* @param actionTask The related transport action task. Can be used to create a task ID to handle upstream transport cancellations.
@@ -228,120 +248,18 @@ protected abstract void taskOperation(
228248
ActionListener<TaskResponse> listener
229249
);
230250

231-
private class AsyncAction {
232-
233-
private final TasksRequest request;
234-
private final String[] nodesIds;
235-
private final DiscoveryNode[] nodes;
236-
private final ActionListener<TasksResponse> listener;
237-
private final AtomicReferenceArray<Object> responses;
238-
private final AtomicInteger counter = new AtomicInteger();
239-
private final Task task;
240-
241-
private AsyncAction(Task task, TasksRequest request, ActionListener<TasksResponse> listener) {
242-
this.task = task;
243-
this.request = request;
244-
this.listener = listener;
245-
final DiscoveryNodes discoveryNodes = clusterService.state().nodes();
246-
this.nodesIds = resolveNodes(request, discoveryNodes);
247-
Map<String, DiscoveryNode> nodes = discoveryNodes.getNodes();
248-
this.nodes = new DiscoveryNode[nodesIds.length];
249-
for (int i = 0; i < this.nodesIds.length; i++) {
250-
this.nodes[i] = nodes.get(this.nodesIds[i]);
251-
}
252-
this.responses = new AtomicReferenceArray<>(this.nodesIds.length);
253-
}
254-
255-
private void start() {
256-
if (nodesIds.length == 0) {
257-
// nothing to do
258-
try {
259-
listener.onResponse(newResponse(request, responses));
260-
} catch (Exception e) {
261-
logger.debug("failed to generate empty response", e);
262-
listener.onFailure(e);
263-
}
264-
} else {
265-
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.getTimeout());
266-
for (int i = 0; i < nodesIds.length; i++) {
267-
final String nodeId = nodesIds[i];
268-
final int idx = i;
269-
final DiscoveryNode node = nodes[i];
270-
try {
271-
if (node == null) {
272-
onFailure(idx, nodeId, new NoSuchNodeException(nodeId));
273-
} else {
274-
NodeTaskRequest nodeRequest = new NodeTaskRequest(request);
275-
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
276-
transportService.sendRequest(
277-
node,
278-
transportNodeAction,
279-
nodeRequest,
280-
transportRequestOptions,
281-
new TransportResponseHandler<NodeTasksResponse>() {
282-
@Override
283-
public NodeTasksResponse read(StreamInput in) throws IOException {
284-
return new NodeTasksResponse(in);
285-
}
286-
287-
@Override
288-
public void handleResponse(NodeTasksResponse response) {
289-
onOperation(idx, response);
290-
}
291-
292-
@Override
293-
public void handleException(TransportException exp) {
294-
onFailure(idx, node.getId(), exp);
295-
}
296-
}
297-
);
298-
}
299-
} catch (Exception e) {
300-
onFailure(idx, nodeId, e);
301-
}
302-
}
303-
}
304-
}
305-
306-
private void onOperation(int idx, NodeTasksResponse nodeResponse) {
307-
responses.set(idx, nodeResponse);
308-
if (counter.incrementAndGet() == responses.length()) {
309-
finishHim();
310-
}
311-
}
312-
313-
private void onFailure(int idx, String nodeId, Throwable t) {
314-
logger.debug(() -> "failed to execute on node [" + nodeId + "]", t);
315-
316-
responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t));
317-
318-
if (counter.incrementAndGet() == responses.length()) {
319-
finishHim();
320-
}
321-
}
322-
323-
private void finishHim() {
324-
if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
325-
return;
326-
}
327-
TasksResponse finalResponse;
328-
try {
329-
finalResponse = newResponse(request, responses);
330-
} catch (Exception e) {
331-
logger.debug("failed to combine responses from nodes", e);
332-
listener.onFailure(e);
333-
return;
334-
}
335-
listener.onResponse(finalResponse);
336-
}
337-
}
338-
339251
class NodeTransportHandler implements TransportRequestHandler<NodeTaskRequest> {
340252

341253
@Override
342254
public void messageReceived(final NodeTaskRequest request, final TransportChannel channel, Task task) throws Exception {
343255
assert task instanceof CancellableTask;
344-
nodeOperation((CancellableTask) task, request, new ChannelActionListener<>(channel));
256+
TasksRequest tasksRequest = request.tasksRequest;
257+
processTasks(
258+
tasksRequest,
259+
new ChannelActionListener<NodeTasksResponse>(channel).delegateFailure(
260+
(l, tasks) -> nodeOperation((CancellableTask) task, l, tasksRequest, tasks)
261+
)
262+
);
345263
}
346264
}
347265

0 commit comments

Comments
 (0)