diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java index 9e692697cbc85..5f805efe0c176 100644 --- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java @@ -11,25 +11,26 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.action.support.NodeResponseTracker; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.ListenableFuture; +import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportChannel; -import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; -import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import java.io.IOException; @@ -38,6 +39,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.core.Strings.format; + public abstract class TransportNodesAction< NodesRequest extends BaseNodesRequest, NodesResponse extends BaseNodesResponse, @@ -85,7 +88,7 @@ protected TransportNodesAction( this.nodeResponseClass = Objects.requireNonNull(nodeResponseClass); this.transportNodeAction = actionName + "[n]"; - this.finalExecutor = finalExecutor; + this.finalExecutor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor; transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler()); } @@ -123,40 +126,89 @@ protected TransportNodesAction( @Override protected void doExecute(Task task, NodesRequest request, ActionListener listener) { - new AsyncAction(task, request, listener).start(); - } + if (request.concreteNodes() == null) { + resolveRequest(request, clusterService.state()); + assert request.concreteNodes() != null; + } - /** - * Map the responses into {@code nodeResponseClass} responses and {@link FailedNodeException}s, convert to a {@link NodesResponse} and - * pass it to the listener. Fails the listener with a {@link NullPointerException} if {@code nodesResponses} is null. - * - * @param request The associated request. - * @param nodeResponseTracker All node-level responses collected so far - * @throws NodeResponseTracker.DiscardedResponsesException if {@code nodeResponseTracker} has already discarded the intermediate results - * @see #newResponseAsync(Task, BaseNodesRequest, List, List, ActionListener) - */ - // exposed for tests - void newResponse(Task task, NodesRequest request, NodeResponseTracker nodeResponseTracker, ActionListener listener) - throws NodeResponseTracker.DiscardedResponsesException { + final var responses = new ArrayList(request.concreteNodes().length); + final var exceptions = new ArrayList(0); - if (nodeResponseTracker == null) { - listener.onFailure(new NullPointerException("nodesResponses")); - return; + final var resultListener = new ListenableFuture(); + final var resultListenerCompleter = new RunOnce(() -> { + if (task instanceof CancellableTask cancellableTask) { + if (cancellableTask.notifyIfCancelled(resultListener)) { + return; + } + } + // ref releases all happen-before here so no need to be synchronized + threadPool.executor(finalExecutor) + .execute(ActionRunnable.wrap(resultListener, l -> newResponseAsync(task, request, responses, exceptions, l))); + }); + + final var nodeCancellationListener = new ListenableFuture(); // collects node listeners & completes them if cancelled + if (task instanceof CancellableTask cancellableTask) { + cancellableTask.addListener(() -> { + assert cancellableTask.isCancelled(); + resultListenerCompleter.run(); + cancellableTask.notifyIfCancelled(nodeCancellationListener); + }); } - final List responses = new ArrayList<>(); - final List failures = new ArrayList<>(); + final var transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); + + try (var refs = new RefCountingRunnable(() -> { + resultListener.addListener(listener); + resultListenerCompleter.run(); + })) { + for (final var node : request.concreteNodes()) { + final ActionListener nodeResponseListener = ActionListener.notifyOnce(new ActionListener<>() { + @Override + public void onResponse(NodeResponse nodeResponse) { + synchronized (responses) { + responses.add(nodeResponse); + } + } + + @Override + public void onFailure(Exception e) { + if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { + return; + } + + logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e); + synchronized (exceptions) { + exceptions.add(new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e)); + } + } - for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); ++i) { - Object response = nodeResponseTracker.getResponse(i); - if (nodeResponseTracker.getResponse(i)instanceof FailedNodeException failedNodeException) { - failures.add(failedNodeException); - } else { - responses.add(nodeResponseClass.cast(response)); + @Override + public String toString() { + return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]"; + } + }); + + if (task instanceof CancellableTask) { + nodeCancellationListener.addListener(nodeResponseListener); + } + + final var nodeRequest = newNodeRequest(request); + if (task != null) { + nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + } + + transportService.sendRequest( + node, + transportNodeAction, + nodeRequest, + transportRequestOptions, + new ActionListenerResponseHandler<>( + ActionListener.releaseAfter(nodeResponseListener, refs.acquire()), + in -> newNodeResponse(in, node) + ) + ); } } - - newResponseAsync(task, request, responses, failures, listener); } /** @@ -199,141 +251,9 @@ protected void resolveRequest(NodesRequest request, ClusterState clusterState) { request.setConcreteNodes(Arrays.stream(nodesIds).map(clusterState.nodes()::get).toArray(DiscoveryNode[]::new)); } - /** - * Get a backwards compatible transport action name - */ - protected String getTransportNodeAction(DiscoveryNode node) { - return transportNodeAction; - } - - class AsyncAction implements CancellableTask.CancellationListener { - - private final NodesRequest request; - private final ActionListener listener; - private final NodeResponseTracker nodeResponseTracker; - private final Task task; - - AsyncAction(Task task, NodesRequest request, ActionListener listener) { - this.task = task; - this.request = request; - this.listener = listener; - if (request.concreteNodes() == null) { - resolveRequest(request, clusterService.state()); - assert request.concreteNodes() != null; - } - this.nodeResponseTracker = new NodeResponseTracker(request.concreteNodes().length); - } - - void start() { - if (task instanceof CancellableTask cancellableTask) { - cancellableTask.addListener(this); - } - final DiscoveryNode[] nodes = request.concreteNodes(); - if (nodes.length == 0) { - finishHim(); - return; - } - final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); - for (int i = 0; i < nodes.length; i++) { - final int idx = i; - final DiscoveryNode node = nodes[i]; - final String nodeId = node.getId(); - try { - TransportRequest nodeRequest = newNodeRequest(request); - if (task != null) { - nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } - - transportService.sendRequest( - node, - getTransportNodeAction(node), - nodeRequest, - transportRequestOptions, - new TransportResponseHandler() { - @Override - public NodeResponse read(StreamInput in) throws IOException { - return newNodeResponse(in, node); - } - - @Override - public void handleResponse(NodeResponse response) { - onOperation(idx, response); - } - - @Override - public void handleException(TransportException exp) { - onFailure(idx, node.getId(), exp); - } - - @Override - public String toString() { - return "AsyncActionNodeResponseHandler{node=" + node + ", action=" + AsyncAction.this + '}'; - } - } - ); - } catch (Exception e) { - onFailure(idx, nodeId, e); - } - } - } - - // For testing purposes - NodeResponseTracker getNodeResponseTracker() { - return nodeResponseTracker; - } - - private void onOperation(int idx, NodeResponse nodeResponse) { - if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, nodeResponse)) { - finishHim(); - } - } - - private void onFailure(int idx, String nodeId, Throwable t) { - logger.debug(() -> "failed to execute on node [" + nodeId + "]", t); - if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) { - finishHim(); - } - } - - private void finishHim() { - if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) { - return; - } - - final String executor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor; - threadPool.executor(executor).execute(() -> { - try { - newResponse(task, request, nodeResponseTracker, listener); - } catch (NodeResponseTracker.DiscardedResponsesException e) { - // We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take - // follow-up actions - listener.onFailure((Exception) e.getCause()); - } - }); - } - - @Override - public void onCancelled() { - assert task instanceof CancellableTask : "task must be cancellable"; - try { - ((CancellableTask) task).ensureNotCancelled(); - } catch (TaskCancelledException e) { - nodeResponseTracker.discardIntermediateResponses(e); - } - } - - @Override - public String toString() { - return "AsyncAction{request=" + request + ", listener=" + listener + '}'; - } - } - class NodeTransportHandler implements TransportRequestHandler { @Override public void messageReceived(NodeRequest request, TransportChannel channel, Task task) throws Exception { - if (task instanceof CancellableTask) { - ((CancellableTask) task).ensureNotCancelled(); - } channel.sendResponse(nodeOperation(request, task)); } }