diff --git a/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java
new file mode 100644
index 0000000000000..7c6f022eb7e83
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java
@@ -0,0 +1,159 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.util.concurrent.RunOnce;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+
+import java.util.Iterator;
+
+/**
+ * Allows an action to fan-out to several sub-actions and accumulate their results, but which reacts to a cancellation by releasing all
+ * references to itself, and hence the partially-accumulated results, allowing them to be garbage-collected. This is a useful protection for
+ * cases where the results may consume a lot of heap (e.g. stats) but the final response may be delayed by a single slow node for long
+ * enough that the client gives up.
+ *
+ * Note that it's easy to accidentally capture another reference to this class when implementing it, and this will prevent the early release
+ * of any accumulated results. Beware of lambdas and method references. You must test your implementation carefully (using e.g.
+ * {@code ReachabilityChecker}) to make sure it doesn't do this.
+ */
+public abstract class CancellableFanOut- {
+
+ private static final Logger logger = LogManager.getLogger(CancellableFanOut.class);
+
+ /**
+ * Run the fan-out action.
+ *
+ * @param task The task to watch for cancellations. If {@code null} or not a {@link CancellableTask} then the fan-out still
+ * works, just without any cancellation handling.
+ * @param itemsIterator The items over which to fan out. Iterated on the calling thread.
+ * @param listener A listener for the final response, which is completed after all the fanned-out actions have completed. It is not
+ * completed promptly on cancellation. Completed on the thread that handles the final per-item response (or
+ * the calling thread if there are no items).
+ */
+ public final void run(@Nullable Task task, Iterator
- itemsIterator, ActionListener listener) {
+
+ final var cancellableTask = task instanceof CancellableTask ct ? ct : null;
+
+ // Captures the final result as soon as it's known (either on completion or on cancellation) without necessarily completing the
+ // outer listener, because we do not want to complete the outer listener until all sub-tasks are complete
+ final var resultListener = new SubscribableListener();
+
+ // Completes resultListener (either on completion or on cancellation). Captures a reference to 'this', but within a 'RunOnce' so it
+ // is released promptly when executed.
+ final var resultListenerCompleter = new RunOnce(() -> {
+ if (cancellableTask != null && cancellableTask.notifyIfCancelled(resultListener)) {
+ return;
+ }
+ // It's important that we complete resultListener before returning, because otherwise there's a risk that a cancellation arrives
+ // later which might unexpectedly complete the final listener on a transport thread.
+ ActionListener.completeWith(resultListener, this::onCompletion);
+ });
+
+ // Collects the per-item listeners up so they can all be completed exceptionally on cancellation. Never completed successfully.
+ final var itemCancellationListener = new SubscribableListener();
+ if (cancellableTask != null) {
+ cancellableTask.addListener(() -> {
+ assert cancellableTask.isCancelled();
+ resultListenerCompleter.run();
+ cancellableTask.notifyIfCancelled(itemCancellationListener);
+ });
+ }
+
+ try (var refs = new RefCountingRunnable(() -> {
+ // When all sub-tasks are complete, pass the result from resultListener to the outer listener.
+ resultListenerCompleter.run();
+ // resultListener is always complete by this point, so the outer listener is completed on this thread
+ resultListener.addListener(listener);
+ })) {
+ while (itemsIterator.hasNext()) {
+ final var item = itemsIterator.next();
+
+ // Captures a reference to 'this', but within a 'notifyOnce' so it is released promptly when completed.
+ final ActionListener itemResponseListener = ActionListener.notifyOnce(new ActionListener<>() {
+ @Override
+ public void onResponse(ItemResponse itemResponse) {
+ onItemResponse(item, itemResponse);
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ if (cancellableTask != null && cancellableTask.isCancelled()) {
+ // Completed on cancellation so it is released promptly, but there's no need to handle the exception.
+ return;
+ }
+ onItemFailure(item, e);
+ }
+
+ @Override
+ public String toString() {
+ return "[" + CancellableFanOut.this + "][" + item + "]";
+ }
+ });
+
+ if (cancellableTask != null) {
+ if (cancellableTask.isCancelled()) {
+ return;
+ }
+
+ // Register this item's listener for prompt cancellation notification.
+ itemCancellationListener.addListener(itemResponseListener);
+ }
+
+ // Process the item, capturing a ref to make sure the outer listener is completed after this item is processed.
+ sendItemRequest(item, ActionListener.releaseAfter(itemResponseListener, refs.acquire()));
+ }
+ } catch (Exception e) {
+ // NB the listener may have been completed already (by exiting this try block) so this exception may not be sent to the caller,
+ // but we cannot do anything else with it; an exception here is a bug anyway.
+ logger.error("unexpected failure in [" + this + "]", e);
+ assert false : e;
+ throw e;
+ }
+ }
+
+ /**
+ * Run the action (typically by sending a transport request) for an individual item. Called in sequence on the thread that invoked
+ * {@link #run}. May not be called for every item if the task is cancelled during the iteration.
+ *
+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the
+ * early release of any accumulated results. Beware of lambdas, and test carefully.
+ */
+ protected abstract void sendItemRequest(Item item, ActionListener listener);
+
+ /**
+ * Handle a successful response for an item. May be called concurrently for multiple items. Not called if the task is cancelled.
+ *
+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the
+ * early release of any accumulated results. Beware of lambdas, and test carefully.
+ */
+ protected abstract void onItemResponse(Item item, ItemResponse itemResponse);
+
+ /**
+ * Handle a failure for an item. May be called concurrently for multiple items. Not called if the task is cancelled.
+ *
+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the
+ * early release of any accumulated results. Beware of lambdas, and test carefully.
+ */
+ protected abstract void onItemFailure(Item item, Exception e);
+
+ /**
+ * Called when responses for all items have been processed, on the thread that processed the last per-item response. Not called if the
+ * task is cancelled.
+ *
+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the
+ * early release of any accumulated results. Beware of lambdas, and test carefully.
+ */
+ protected abstract FinalResponse onCompletion() throws Exception;
+}
diff --git a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
index f7c4fad29fdfa..aec75e3300481 100644
--- a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
+++ b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
@@ -18,11 +18,11 @@
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.CancellableFanOut;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
-import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse;
import org.elasticsearch.action.support.broadcast.BroadcastRequest;
@@ -37,9 +37,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.util.concurrent.ListenableFuture;
-import org.elasticsearch.common.util.concurrent.RunOnce;
-import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportChannel;
@@ -280,100 +277,18 @@ private void executeAsCoordinatingNode(
ResponseFactory responseFactory,
ActionListener listener
) {
- final var mutex = new Object();
- final var shardResponses = new ArrayList(availableShardCount);
- final var exceptions = new ArrayList(0);
- final var totalShards = new AtomicInteger(unavailableShardCount);
- final var successfulShards = new AtomicInteger(0);
-
- final var resultListener = new ListenableFuture();
- final var resultListenerCompleter = new RunOnce(() -> {
- if (task instanceof CancellableTask cancellableTask) {
- if (cancellableTask.notifyIfCancelled(resultListener)) {
- return;
- }
- }
- // ref releases all happen-before here so no need to be synchronized
- resultListener.onResponse(
- responseFactory.newResponse(totalShards.get(), successfulShards.get(), exceptions.size(), shardResponses, exceptions)
- );
- });
-
- final var nodeFailureListeners = new ListenableFuture();
- if (task instanceof CancellableTask cancellableTask) {
- cancellableTask.addListener(() -> {
- assert cancellableTask.isCancelled();
- resultListenerCompleter.run();
- cancellableTask.notifyIfCancelled(nodeFailureListeners);
- });
- }
-
- final var transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
-
- try (var refs = new RefCountingRunnable(() -> {
- resultListener.addListener(listener);
- resultListenerCompleter.run();
- })) {
- for (final var entry : shardsByNodeId.entrySet()) {
+ new CancellableFanOut>, NodeResponse, Response>() {
+ final ArrayList shardResponses = new ArrayList<>(availableShardCount);
+ final ArrayList exceptions = new ArrayList<>(0);
+ final AtomicInteger totalShards = new AtomicInteger(unavailableShardCount);
+ final AtomicInteger successfulShards = new AtomicInteger(0);
+ final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
+
+ @Override
+ protected void sendItemRequest(Map.Entry> entry, ActionListener listener) {
final var node = nodes.get(entry.getKey());
final var shards = entry.getValue();
- final ActionListener nodeResponseListener = ActionListener.notifyOnce(new ActionListener() {
- @Override
- public void onResponse(NodeResponse nodeResponse) {
- synchronized (mutex) {
- shardResponses.addAll(nodeResponse.getResults());
- }
- totalShards.addAndGet(nodeResponse.getTotalShards());
- successfulShards.addAndGet(nodeResponse.getSuccessfulShards());
-
- for (BroadcastShardOperationFailedException exception : nodeResponse.getExceptions()) {
- if (TransportActions.isShardNotAvailableException(exception)) {
- assert node.getVersion().before(Version.V_8_7_0) : node; // we stopped sending these ignored exceptions
- } else {
- synchronized (mutex) {
- exceptions.add(
- new DefaultShardOperationFailedException(
- exception.getShardId().getIndexName(),
- exception.getShardId().getId(),
- exception
- )
- );
- }
- }
- }
- }
-
- @Override
- public void onFailure(Exception e) {
- if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) {
- return;
- }
-
- logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e);
-
- final var failedNodeException = new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e);
- synchronized (mutex) {
- for (ShardRouting shard : shards) {
- exceptions.add(
- new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), failedNodeException)
- );
- }
- }
-
- totalShards.addAndGet(shards.size());
- }
-
- @Override
- public String toString() {
- return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]";
- }
- });
-
- if (task instanceof CancellableTask) {
- nodeFailureListeners.addListener(nodeResponseListener);
- }
-
final var nodeRequest = new NodeRequest(request, shards, node.getId());
if (task != null) {
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
@@ -384,15 +299,74 @@ public String toString() {
transportNodeBroadcastAction,
nodeRequest,
transportRequestOptions,
- new ActionListenerResponseHandler<>(
- ActionListener.releaseAfter(nodeResponseListener, refs.acquire()),
- NodeResponse::new
- )
+ new ActionListenerResponseHandler<>(listener, nodeResponseReader)
);
}
- }
+
+ @Override
+ protected void onItemResponse(Map.Entry> entry, NodeResponse nodeResponse) {
+ final var node = nodes.get(entry.getKey());
+ synchronized (this) {
+ shardResponses.addAll(nodeResponse.getResults());
+ }
+ totalShards.addAndGet(nodeResponse.getTotalShards());
+ successfulShards.addAndGet(nodeResponse.getSuccessfulShards());
+
+ for (BroadcastShardOperationFailedException exception : nodeResponse.getExceptions()) {
+ if (TransportActions.isShardNotAvailableException(exception)) {
+ assert node.getVersion().before(Version.V_8_7_0) : node; // we stopped sending these ignored exceptions
+ } else {
+ synchronized (this) {
+ exceptions.add(
+ new DefaultShardOperationFailedException(
+ exception.getShardId().getIndexName(),
+ exception.getShardId().getId(),
+ exception
+ )
+ );
+ }
+ }
+ }
+ }
+
+ @Override
+ protected void onItemFailure(Map.Entry> entry, Exception e) {
+ final var node = nodes.get(entry.getKey());
+ final var shards = entry.getValue();
+ logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e);
+
+ final var failedNodeException = new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e);
+ synchronized (this) {
+ for (ShardRouting shard : shards) {
+ exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), failedNodeException));
+ }
+ }
+
+ totalShards.addAndGet(shards.size());
+ }
+
+ @Override
+ protected Response onCompletion() {
+ // ref releases all happen-before here so no need to be synchronized
+ return responseFactory.newResponse(
+ totalShards.get(),
+ successfulShards.get(),
+ exceptions.size(),
+ shardResponses,
+ exceptions
+ );
+ }
+
+ @Override
+ public String toString() {
+ return actionName;
+ }
+ }.run(task, shardsByNodeId.entrySet().iterator(), listener);
}
+ // not an inline method reference to avoid capturing CancellableFanOut.this.
+ private final Writeable.Reader nodeResponseReader = NodeResponse::new;
+
class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler {
@Override
public void messageReceived(final NodeRequest request, TransportChannel channel, Task task) throws Exception {
@@ -415,87 +389,51 @@ private void executeAsDataNode(
) {
logger.trace("[{}] executing operation on [{}] shards", actionName, shards.size());
- final var results = new ArrayList(shards.size());
- final var exceptions = new ArrayList(0);
+ new CancellableFanOut() {
- final var resultListener = new ListenableFuture();
- final var resultListenerCompleter = new RunOnce(() -> {
- if (task instanceof CancellableTask cancellableTask) {
- if (cancellableTask.notifyIfCancelled(resultListener)) {
- return;
- }
+ final ArrayList results = new ArrayList<>(shards.size());
+ final ArrayList exceptions = new ArrayList<>(0);
+
+ @Override
+ protected void sendItemRequest(ShardRouting shardRouting, ActionListener listener) {
+ logger.trace(() -> format("[%s] executing operation for shard [%s]", actionName, shardRouting.shortSummary()));
+ ActionRunnable.wrap(listener, l -> shardOperation(request, shardRouting, task, l)).run();
}
- // ref releases all happen-before here so no need to be synchronized
- resultListener.onResponse(new NodeResponse(nodeId, shards.size(), results, exceptions));
- });
-
- final var shardFailureListeners = new ListenableFuture();
- if (task instanceof CancellableTask cancellableTask) {
- cancellableTask.addListener(() -> {
- assert cancellableTask.isCancelled();
- resultListenerCompleter.run();
- cancellableTask.notifyIfCancelled(shardFailureListeners);
- });
- }
- try (var refs = new RefCountingRunnable(() -> {
- resultListener.addListener(listener);
- resultListenerCompleter.run();
- })) {
- for (final var shardRouting : shards) {
- if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) {
- return;
+ @Override
+ protected void onItemResponse(ShardRouting shardRouting, ShardOperationResult shardOperationResult) {
+ synchronized (results) {
+ results.add(shardOperationResult);
}
+ }
- final ActionListener shardListener = ActionListener.notifyOnce(new ActionListener<>() {
- @Override
- public void onResponse(ShardOperationResult shardOperationResult) {
- logger.trace(() -> format("[%s] completed operation for shard [%s]", actionName, shardRouting.shortSummary()));
- synchronized (results) {
- results.add(shardOperationResult);
- }
- }
-
- @Override
- public void onFailure(Exception e) {
- if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) {
- return;
- }
- logger.log(
- TransportActions.isShardNotAvailableException(e) ? Level.TRACE : Level.DEBUG,
- () -> format("[%s] failed to execute operation for shard [%s]", actionName, shardRouting.shortSummary()),
- e
+ @Override
+ protected void onItemFailure(ShardRouting shardRouting, Exception e) {
+ logger.log(
+ TransportActions.isShardNotAvailableException(e) ? Level.TRACE : Level.DEBUG,
+ () -> format("[%s] failed to execute operation for shard [%s]", actionName, shardRouting.shortSummary()),
+ e
+ );
+ if (TransportActions.isShardNotAvailableException(e) == false) {
+ synchronized (exceptions) {
+ exceptions.add(
+ new BroadcastShardOperationFailedException(shardRouting.shardId(), "operation " + actionName + " failed", e)
);
- if (TransportActions.isShardNotAvailableException(e) == false) {
- synchronized (exceptions) {
- exceptions.add(
- new BroadcastShardOperationFailedException(
- shardRouting.shardId(),
- "operation " + actionName + " failed",
- e
- )
- );
- }
- }
}
-
- @Override
- public String toString() {
- return "[" + actionName + "][" + shardRouting + "]";
- }
- });
-
- if (task instanceof CancellableTask) {
- shardFailureListeners.addListener(shardListener);
}
+ }
- logger.trace(() -> format("[%s] executing operation for shard [%s]", actionName, shardRouting.shortSummary()));
- ActionRunnable.wrap(
- ActionListener.releaseAfter(shardListener, refs.acquire()),
- l -> shardOperation(request, shardRouting, task, l)
- ).run();
+ @Override
+ protected NodeResponse onCompletion() {
+ // ref releases all happen-before here so no need to be synchronized
+ return new NodeResponse(nodeId, shards.size(), results, exceptions);
}
- }
+
+ @Override
+ public String toString() {
+ return actionName;
+ }
+ }.run(task, shards.iterator(), listener);
}
class NodeRequest extends TransportRequest implements IndicesRequest {
diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java
index 5f805efe0c176..fedd357501ac1 100644
--- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java
+++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java
@@ -15,16 +15,15 @@
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.CancellableFanOut;
import org.elasticsearch.action.support.HandledTransportAction;
-import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.util.concurrent.ListenableFuture;
-import org.elasticsearch.common.util.concurrent.RunOnce;
-import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
@@ -131,84 +130,64 @@ protected void doExecute(Task task, NodesRequest request, ActionListener(request.concreteNodes().length);
- final var exceptions = new ArrayList(0);
+ new CancellableFanOut, Exception>>() {
- final var resultListener = new ListenableFuture();
- final var resultListenerCompleter = new RunOnce(() -> {
- if (task instanceof CancellableTask cancellableTask) {
- if (cancellableTask.notifyIfCancelled(resultListener)) {
- return;
- }
- }
- // ref releases all happen-before here so no need to be synchronized
- threadPool.executor(finalExecutor)
- .execute(ActionRunnable.wrap(resultListener, l -> newResponseAsync(task, request, responses, exceptions, l)));
- });
-
- final var nodeCancellationListener = new ListenableFuture(); // collects node listeners & completes them if cancelled
- if (task instanceof CancellableTask cancellableTask) {
- cancellableTask.addListener(() -> {
- assert cancellableTask.isCancelled();
- resultListenerCompleter.run();
- cancellableTask.notifyIfCancelled(nodeCancellationListener);
- });
- }
-
- final var transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
-
- try (var refs = new RefCountingRunnable(() -> {
- resultListener.addListener(listener);
- resultListenerCompleter.run();
- })) {
- for (final var node : request.concreteNodes()) {
- final ActionListener nodeResponseListener = ActionListener.notifyOnce(new ActionListener<>() {
- @Override
- public void onResponse(NodeResponse nodeResponse) {
- synchronized (responses) {
- responses.add(nodeResponse);
- }
- }
+ final ArrayList responses = new ArrayList<>(request.concreteNodes().length);
+ final ArrayList exceptions = new ArrayList<>(0);
- @Override
- public void onFailure(Exception e) {
- if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) {
- return;
- }
-
- logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e);
- synchronized (exceptions) {
- exceptions.add(new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e));
- }
- }
-
- @Override
- public String toString() {
- return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]";
- }
- });
-
- if (task instanceof CancellableTask) {
- nodeCancellationListener.addListener(nodeResponseListener);
- }
+ final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
+ @Override
+ protected void sendItemRequest(DiscoveryNode discoveryNode, ActionListener listener) {
final var nodeRequest = newNodeRequest(request);
if (task != null) {
nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId());
}
transportService.sendRequest(
- node,
+ discoveryNode,
transportNodeAction,
nodeRequest,
transportRequestOptions,
- new ActionListenerResponseHandler<>(
- ActionListener.releaseAfter(nodeResponseListener, refs.acquire()),
- in -> newNodeResponse(in, node)
- )
+ new ActionListenerResponseHandler<>(listener, nodeResponseReader(discoveryNode))
);
}
- }
+
+ @Override
+ protected void onItemResponse(DiscoveryNode discoveryNode, NodeResponse nodeResponse) {
+ synchronized (responses) {
+ responses.add(nodeResponse);
+ }
+ }
+
+ @Override
+ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) {
+ logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, discoveryNode), e);
+ synchronized (exceptions) {
+ exceptions.add(new FailedNodeException(discoveryNode.getId(), "Failed node [" + discoveryNode.getId() + "]", e));
+ }
+ }
+
+ @Override
+ protected CheckedConsumer, Exception> onCompletion() {
+ // ref releases all happen-before here so no need to be synchronized
+ return l -> newResponseAsync(task, request, responses, exceptions, l);
+ }
+
+ @Override
+ public String toString() {
+ return actionName;
+ }
+ }.run(
+ task,
+ Iterators.forArray(request.concreteNodes()),
+ listener.delegateFailure((l, r) -> threadPool.executor(finalExecutor).execute(ActionRunnable.wrap(l, r)))
+ );
+ }
+
+ private Writeable.Reader nodeResponseReader(DiscoveryNode discoveryNode) {
+ // not an inline lambda to avoid capturing CancellableFanOut.this.
+ return in -> TransportNodesAction.this.newNodeResponse(in, discoveryNode);
}
/**
diff --git a/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java
new file mode 100644
index 0000000000000..db48b09e95a08
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelHelper;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.ReachabilityChecker;
+import org.hamcrest.Matchers;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class CancellableFanOutTests extends ESTestCase {
+
+ public void testFanOutWithoutCancellation() {
+ final var task = randomFrom(
+ new Task(1, "test", "test", "", TaskId.EMPTY_TASK_ID, Map.of()),
+ new CancellableTask(1, "test", "test", "", TaskId.EMPTY_TASK_ID, Map.of()),
+ null
+ );
+ final var future = new PlainActionFuture();
+
+ final var itemListeners = new HashMap>();
+ final var finalFailure = randomBoolean();
+
+ new CancellableFanOut() {
+ int counter;
+
+ @Override
+ protected void sendItemRequest(String item, ActionListener listener) {
+ itemListeners.put(item, listener);
+ }
+
+ @Override
+ protected void onItemResponse(String item, String itemResponse) {
+ assertThat(item, Matchers.oneOf("a", "c"));
+ assertEquals(item + "-response", itemResponse);
+ counter += 1;
+ }
+
+ @Override
+ protected void onItemFailure(String item, Exception e) {
+ assertEquals("b", item);
+ counter += 1;
+ }
+
+ @Override
+ protected String onCompletion() {
+ assertEquals(3, counter);
+ if (finalFailure) {
+ throw new ElasticsearchException("failed");
+ } else {
+ return "completed";
+ }
+ }
+ }.run(task, List.of("a", "b", "c").iterator(), future);
+
+ itemListeners.remove("a").onResponse("a-response");
+ assertFalse(future.isDone());
+ itemListeners.remove("b").onFailure(new ElasticsearchException("b-response"));
+ assertFalse(future.isDone());
+ itemListeners.remove("c").onResponse("c-response");
+ assertTrue(future.isDone());
+ if (finalFailure) {
+ assertEquals("failed", expectThrows(ElasticsearchException.class, future::actionGet).getMessage());
+ } else {
+ assertEquals("completed", future.actionGet());
+ }
+ }
+
+ public void testReleaseOnCancellation() {
+ final var task = new CancellableTask(1, "test", "test", "", TaskId.EMPTY_TASK_ID, Map.of());
+ final var future = new PlainActionFuture();
+
+ final var itemListeners = new HashMap>();
+ final var handledItemResponse = new AtomicBoolean();
+
+ final var reachabilityChecker = new ReachabilityChecker();
+ reachabilityChecker.register(new CancellableFanOut() {
+ @Override
+ protected void sendItemRequest(String item, ActionListener listener) {
+ itemListeners.put(item, listener);
+ }
+
+ @Override
+ protected void onItemResponse(String item, String itemResponse) {
+ assertEquals("a", item);
+ assertEquals("a-response", itemResponse);
+ assertTrue(handledItemResponse.compareAndSet(false, true));
+ }
+
+ @Override
+ protected void onItemFailure(String item, Exception e) {
+ fail(item);
+ }
+
+ @Override
+ protected String onCompletion() {
+ throw new AssertionError("onCompletion");
+ }
+ }).run(task, List.of("a", "b", "c").iterator(), future);
+
+ itemListeners.remove("a").onResponse("a-response");
+ assertTrue(handledItemResponse.get());
+ reachabilityChecker.checkReachable();
+
+ TaskCancelHelper.cancel(task, "test");
+ reachabilityChecker.ensureUnreachable(); // even though we're still holding on to some item listeners.
+ assertFalse(future.isDone());
+
+ itemListeners.remove("b").onResponse("b-response");
+ assertFalse(future.isDone());
+
+ itemListeners.remove("c").onFailure(new ElasticsearchException("c-response"));
+ assertTrue(itemListeners.isEmpty());
+ assertTrue(future.isDone());
+ expectThrows(TaskCancelledException.class, future::actionGet);
+ }
+}