Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,26 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.NodeResponseTracker;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
Expand All @@ -38,6 +39,8 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;

public abstract class TransportNodesAction<
NodesRequest extends BaseNodesRequest<NodesRequest>,
NodesResponse extends BaseNodesResponse<?>,
Expand Down Expand Up @@ -85,7 +88,7 @@ protected TransportNodesAction(
this.nodeResponseClass = Objects.requireNonNull(nodeResponseClass);

this.transportNodeAction = actionName + "[n]";
this.finalExecutor = finalExecutor;
this.finalExecutor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor;
transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler());
}

Expand Down Expand Up @@ -123,40 +126,89 @@ protected TransportNodesAction(

@Override
protected void doExecute(Task task, NodesRequest request, ActionListener<NodesResponse> listener) {
new AsyncAction(task, request, listener).start();
}
if (request.concreteNodes() == null) {
resolveRequest(request, clusterService.state());
assert request.concreteNodes() != null;
}

/**
* Map the responses into {@code nodeResponseClass} responses and {@link FailedNodeException}s, convert to a {@link NodesResponse} and
* pass it to the listener. Fails the listener with a {@link NullPointerException} if {@code nodesResponses} is null.
*
* @param request The associated request.
* @param nodeResponseTracker All node-level responses collected so far
* @throws NodeResponseTracker.DiscardedResponsesException if {@code nodeResponseTracker} has already discarded the intermediate results
* @see #newResponseAsync(Task, BaseNodesRequest, List, List, ActionListener)
*/
// exposed for tests
void newResponse(Task task, NodesRequest request, NodeResponseTracker nodeResponseTracker, ActionListener<NodesResponse> listener)
throws NodeResponseTracker.DiscardedResponsesException {
final var responses = new ArrayList<NodeResponse>(request.concreteNodes().length);
final var exceptions = new ArrayList<FailedNodeException>(0);

if (nodeResponseTracker == null) {
listener.onFailure(new NullPointerException("nodesResponses"));
return;
final var resultListener = new ListenableFuture<NodesResponse>();
final var resultListenerCompleter = new RunOnce(() -> {
if (task instanceof CancellableTask cancellableTask) {
if (cancellableTask.notifyIfCancelled(resultListener)) {
return;
}
}
// ref releases all happen-before here so no need to be synchronized
threadPool.executor(finalExecutor)
.execute(ActionRunnable.wrap(resultListener, l -> newResponseAsync(task, request, responses, exceptions, l)));
});

final var nodeCancellationListener = new ListenableFuture<NodeResponse>(); // collects node listeners & completes them if cancelled
if (task instanceof CancellableTask cancellableTask) {
cancellableTask.addListener(() -> {
assert cancellableTask.isCancelled();
resultListenerCompleter.run();
cancellableTask.notifyIfCancelled(nodeCancellationListener);
});
}

final List<NodeResponse> responses = new ArrayList<>();
final List<FailedNodeException> failures = new ArrayList<>();
final var transportRequestOptions = TransportRequestOptions.timeout(request.timeout());

try (var refs = new RefCountingRunnable(() -> {
resultListener.addListener(listener);
resultListenerCompleter.run();
})) {
for (final var node : request.concreteNodes()) {
final ActionListener<NodeResponse> nodeResponseListener = ActionListener.notifyOnce(new ActionListener<>() {
@Override
public void onResponse(NodeResponse nodeResponse) {
synchronized (responses) {
responses.add(nodeResponse);
}
}

@Override
public void onFailure(Exception e) {
if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) {
return;
}

logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e);
synchronized (exceptions) {
exceptions.add(new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e));
}
}

for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); ++i) {
Object response = nodeResponseTracker.getResponse(i);
if (nodeResponseTracker.getResponse(i)instanceof FailedNodeException failedNodeException) {
failures.add(failedNodeException);
} else {
responses.add(nodeResponseClass.cast(response));
@Override
public String toString() {
return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]";
}
});

if (task instanceof CancellableTask) {
nodeCancellationListener.addListener(nodeResponseListener);
}

final var nodeRequest = newNodeRequest(request);
if (task != null) {
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
}

transportService.sendRequest(
node,
transportNodeAction,
nodeRequest,
transportRequestOptions,
new ActionListenerResponseHandler<>(
ActionListener.releaseAfter(nodeResponseListener, refs.acquire()),
in -> newNodeResponse(in, node)
)
);
}
}

newResponseAsync(task, request, responses, failures, listener);
}

/**
Expand Down Expand Up @@ -199,141 +251,9 @@ protected void resolveRequest(NodesRequest request, ClusterState clusterState) {
request.setConcreteNodes(Arrays.stream(nodesIds).map(clusterState.nodes()::get).toArray(DiscoveryNode[]::new));
}

/**
* Get a backwards compatible transport action name
*/
protected String getTransportNodeAction(DiscoveryNode node) {
return transportNodeAction;
}

class AsyncAction implements CancellableTask.CancellationListener {

private final NodesRequest request;
private final ActionListener<NodesResponse> listener;
private final NodeResponseTracker nodeResponseTracker;
private final Task task;

AsyncAction(Task task, NodesRequest request, ActionListener<NodesResponse> listener) {
this.task = task;
this.request = request;
this.listener = listener;
if (request.concreteNodes() == null) {
resolveRequest(request, clusterService.state());
assert request.concreteNodes() != null;
}
this.nodeResponseTracker = new NodeResponseTracker(request.concreteNodes().length);
}

void start() {
if (task instanceof CancellableTask cancellableTask) {
cancellableTask.addListener(this);
}
final DiscoveryNode[] nodes = request.concreteNodes();
if (nodes.length == 0) {
finishHim();
return;
}
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
for (int i = 0; i < nodes.length; i++) {
final int idx = i;
final DiscoveryNode node = nodes[i];
final String nodeId = node.getId();
try {
TransportRequest nodeRequest = newNodeRequest(request);
if (task != null) {
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
}

transportService.sendRequest(
node,
getTransportNodeAction(node),
nodeRequest,
transportRequestOptions,
new TransportResponseHandler<NodeResponse>() {
@Override
public NodeResponse read(StreamInput in) throws IOException {
return newNodeResponse(in, node);
}

@Override
public void handleResponse(NodeResponse response) {
onOperation(idx, response);
}

@Override
public void handleException(TransportException exp) {
onFailure(idx, node.getId(), exp);
}

@Override
public String toString() {
return "AsyncActionNodeResponseHandler{node=" + node + ", action=" + AsyncAction.this + '}';
}
}
);
} catch (Exception e) {
onFailure(idx, nodeId, e);
}
}
}

// For testing purposes
NodeResponseTracker getNodeResponseTracker() {
return nodeResponseTracker;
}

private void onOperation(int idx, NodeResponse nodeResponse) {
if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, nodeResponse)) {
finishHim();
}
}

private void onFailure(int idx, String nodeId, Throwable t) {
logger.debug(() -> "failed to execute on node [" + nodeId + "]", t);
if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
finishHim();
}
}

private void finishHim() {
if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
return;
}

final String executor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor;
threadPool.executor(executor).execute(() -> {
try {
newResponse(task, request, nodeResponseTracker, listener);
} catch (NodeResponseTracker.DiscardedResponsesException e) {
// We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
// follow-up actions
listener.onFailure((Exception) e.getCause());
}
});
}

@Override
public void onCancelled() {
assert task instanceof CancellableTask : "task must be cancellable";
try {
((CancellableTask) task).ensureNotCancelled();
} catch (TaskCancelledException e) {
nodeResponseTracker.discardIntermediateResponses(e);
}
}

@Override
public String toString() {
return "AsyncAction{request=" + request + ", listener=" + listener + '}';
}
}

class NodeTransportHandler implements TransportRequestHandler<NodeRequest> {
@Override
public void messageReceived(NodeRequest request, TransportChannel channel, Task task) throws Exception {
if (task instanceof CancellableTask) {
((CancellableTask) task).ensureNotCancelled();
}
channel.sendResponse(nodeOperation(request, task));
}
}
Expand Down