From 19bb31a0e58a2637619129ed746c411320bd14e5 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 26 May 2023 10:02:39 +0100 Subject: [PATCH 1/3] Introduce CancellableFanOut We have this somewhat-complex pattern in 3 places already, and #96279 will introduce a couple more, so this commit extracts it as a dedicated utility. Relates #92987 Relates #93484 --- .../action/support/CancellableFanOut.java | 161 ++++++++++ .../node/TransportBroadcastByNodeAction.java | 275 +++++++----------- .../support/nodes/TransportNodesAction.java | 111 +++---- .../support/CancellableFanOutTests.java | 124 ++++++++ 4 files changed, 433 insertions(+), 238 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java create mode 100644 server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java diff --git a/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java new file mode 100644 index 0000000000000..dfe4ebe24ae52 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.action.support; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.RunOnce; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; + +import java.util.Iterator; + +/** + * Allows an action to fan-out to several sub-actions and accumulate their results, but which reacts to a cancellation by releasing all + * references to itself, and hence the partially-accumulated results, allowing them to be garbage-collected. This is a useful protection for + * cases where the results may consume a lot of heap (e.g. stats) but the final response may be delayed by a single slow node for long + * enough that the client gives up. + *

+ * Note that it's easy to accidentally capture another reference to this class when implementing it, and this will prevent the early release + * of any accumulated results. Beware of lambdas and method references. You must test your implementation carefully (using e.g. + * {@code ReachabilityChecker}) to make sure it doesn't do this. + */ +public abstract class CancellableFanOut { + + private static final Logger logger = LogManager.getLogger(CancellableFanOut.class); + + /** + * Run the fan-out action. + * + * @param task The task to watch for cancellations. If {@code null} or not a {@link CancellableTask} then the fan-out still + * works, just without any cancellation handling. + * @param itemsIterator The items over which to fan out. Iterated on the calling thread. + * @param listener A listener for the final response, which is completed after all the fanned-out actions have completed. It is not + * completed promptly on cancellation. Completed on the thread that handles the final per-item response (or + * the calling thread if there are no items). + */ + public final void run(@Nullable Task task, Iterator itemsIterator, ActionListener listener) { + + final var cancellableTask = task instanceof CancellableTask ct ? ct : null; + + // Captures the final result as soon as it's known (either on completion or on cancellation) without necessarily completing the + // outer listener, because we do not want to complete the outer listener until all sub-tasks are complete + final var resultListener = new SubscribableListener(); + + // Completes resultListener (either on completion or on cancellation). Captures a reference to 'this', but within a 'RunOnce' so it + // is released promptly when executed. 
+ final var resultListenerCompleter = new RunOnce(() -> { + if (cancellableTask != null && cancellableTask.notifyIfCancelled(resultListener)) { + return; + } + onCompletion(resultListener); + + // It's important that onCompletion() completes resultListener before returning, because otherwise there's a risk that + // a cancellation arrives later which might unexpectedly complete the final listener on a transport thread. + assert resultListener.isDone() : "onCompletion did not complete its listener"; + }); + + // Collects the per-item listeners up so they can all be completed exceptionally on cancellation. Never completed successfully. + final var itemCancellationListener = new SubscribableListener(); + if (cancellableTask != null) { + cancellableTask.addListener(() -> { + assert cancellableTask.isCancelled(); + resultListenerCompleter.run(); + cancellableTask.notifyIfCancelled(itemCancellationListener); + }); + } + + try (var refs = new RefCountingRunnable(() -> { + // When all sub-tasks are complete, pass the result from resultListener to the outer listener. + resultListenerCompleter.run(); + // resultListener is always complete by this point, so the outer listener is completed on this thread + resultListener.addListener(listener); + })) { + while (itemsIterator.hasNext()) { + final var item = itemsIterator.next(); + + // Captures a reference to 'this', but within a 'notifyOnce' so it is released promptly when completed. + final ActionListener itemResponseListener = ActionListener.notifyOnce(new ActionListener<>() { + @Override + public void onResponse(ItemResponse itemResponse) { + onItemResponse(item, itemResponse); + } + + @Override + public void onFailure(Exception e) { + if (cancellableTask != null && cancellableTask.isCancelled()) { + // Completed on cancellation so it is released promptly, but there's no need to handle the exception. 
+ return; + } + onItemFailure(item, e); + } + + @Override + public String toString() { + return "[" + CancellableFanOut.this + "][" + item + "]"; + } + }); + + if (cancellableTask != null) { + if (cancellableTask.isCancelled()) { + return; + } + + // Register this item's listener for prompt cancellation notification. + itemCancellationListener.addListener(itemResponseListener); + } + + // Process the item, capturing a ref to make sure the outer listener is completed after this item is processed. + sendItemRequest(item, ActionListener.releaseAfter(itemResponseListener, refs.acquire())); + } + } catch (Exception e) { + // NB the listener may have been completed already (by exiting this try block) so this exception may not be sent to the caller, + // but we cannot do anything else with it; an exception here is a bug anyway. + logger.error("unexpected failure in [" + this + "]", e); + assert false : e; + throw e; + } + } + + /** + * Run the action (typically by sending a transport request) for an individual item. Called in sequence on the thread that invoked + * {@link #run}. May not be called for every item if the task is cancelled during the iteration. + *

+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the + * early release of any accumulated results. Beware of lambdas, and test carefully. + */ + protected abstract void sendItemRequest(Item item, ActionListener listener); + + /** + * Handle a successful response for an item. May be called concurrently for multiple items. Not called if the task is cancelled. + *

+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the + * early release of any accumulated results. Beware of lambdas, and test carefully. + */ + protected abstract void onItemResponse(Item item, ItemResponse itemResponse); + + /** + * Handle a failure for an item. May be called concurrently for multiple items. Not called if the task is cancelled. + *

+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the + * early release of any accumulated results. Beware of lambdas, and test carefully. + */ + protected abstract void onItemFailure(Item item, Exception e); + + /** + * Called when responses for all items have been processed, on the thread that processed the last per-item response. Not called if the + * task is cancelled. Must complete the listener before returning. + *

+ * Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the + * early release of any accumulated results. Beware of lambdas, and test carefully. + */ + protected abstract void onCompletion(ActionListener listener); +} diff --git a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java index f7c4fad29fdfa..25b1060461867 100644 --- a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java @@ -18,11 +18,11 @@ import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.CancellableFanOut; import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.DefaultShardOperationFailedException; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse; import org.elasticsearch.action.support.broadcast.BroadcastRequest; @@ -37,9 +37,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.ListenableFuture; -import org.elasticsearch.common.util.concurrent.RunOnce; -import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import 
org.elasticsearch.transport.TransportChannel; @@ -280,100 +277,18 @@ private void executeAsCoordinatingNode( ResponseFactory responseFactory, ActionListener listener ) { - final var mutex = new Object(); - final var shardResponses = new ArrayList(availableShardCount); - final var exceptions = new ArrayList(0); - final var totalShards = new AtomicInteger(unavailableShardCount); - final var successfulShards = new AtomicInteger(0); - - final var resultListener = new ListenableFuture(); - final var resultListenerCompleter = new RunOnce(() -> { - if (task instanceof CancellableTask cancellableTask) { - if (cancellableTask.notifyIfCancelled(resultListener)) { - return; - } - } - // ref releases all happen-before here so no need to be synchronized - resultListener.onResponse( - responseFactory.newResponse(totalShards.get(), successfulShards.get(), exceptions.size(), shardResponses, exceptions) - ); - }); - - final var nodeFailureListeners = new ListenableFuture(); - if (task instanceof CancellableTask cancellableTask) { - cancellableTask.addListener(() -> { - assert cancellableTask.isCancelled(); - resultListenerCompleter.run(); - cancellableTask.notifyIfCancelled(nodeFailureListeners); - }); - } - - final var transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); - - try (var refs = new RefCountingRunnable(() -> { - resultListener.addListener(listener); - resultListenerCompleter.run(); - })) { - for (final var entry : shardsByNodeId.entrySet()) { + new CancellableFanOut>, NodeResponse, Response>() { + final ArrayList shardResponses = new ArrayList<>(availableShardCount); + final ArrayList exceptions = new ArrayList<>(0); + final AtomicInteger totalShards = new AtomicInteger(unavailableShardCount); + final AtomicInteger successfulShards = new AtomicInteger(0); + final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); + + @Override + protected void sendItemRequest(Map.Entry> entry, ActionListener 
listener) { final var node = nodes.get(entry.getKey()); final var shards = entry.getValue(); - final ActionListener nodeResponseListener = ActionListener.notifyOnce(new ActionListener() { - @Override - public void onResponse(NodeResponse nodeResponse) { - synchronized (mutex) { - shardResponses.addAll(nodeResponse.getResults()); - } - totalShards.addAndGet(nodeResponse.getTotalShards()); - successfulShards.addAndGet(nodeResponse.getSuccessfulShards()); - - for (BroadcastShardOperationFailedException exception : nodeResponse.getExceptions()) { - if (TransportActions.isShardNotAvailableException(exception)) { - assert node.getVersion().before(Version.V_8_7_0) : node; // we stopped sending these ignored exceptions - } else { - synchronized (mutex) { - exceptions.add( - new DefaultShardOperationFailedException( - exception.getShardId().getIndexName(), - exception.getShardId().getId(), - exception - ) - ); - } - } - } - } - - @Override - public void onFailure(Exception e) { - if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { - return; - } - - logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e); - - final var failedNodeException = new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e); - synchronized (mutex) { - for (ShardRouting shard : shards) { - exceptions.add( - new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), failedNodeException) - ); - } - } - - totalShards.addAndGet(shards.size()); - } - - @Override - public String toString() { - return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]"; - } - }); - - if (task instanceof CancellableTask) { - nodeFailureListeners.addListener(nodeResponseListener); - } - final var nodeRequest = new NodeRequest(request, shards, node.getId()); if (task != null) { nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); @@ -384,15 +299,70 @@ public String toString() { 
transportNodeBroadcastAction, nodeRequest, transportRequestOptions, - new ActionListenerResponseHandler<>( - ActionListener.releaseAfter(nodeResponseListener, refs.acquire()), - NodeResponse::new - ) + new ActionListenerResponseHandler<>(listener, nodeResponseReader) ); } - } + + @Override + protected void onItemResponse(Map.Entry> entry, NodeResponse nodeResponse) { + final var node = nodes.get(entry.getKey()); + synchronized (this) { + shardResponses.addAll(nodeResponse.getResults()); + } + totalShards.addAndGet(nodeResponse.getTotalShards()); + successfulShards.addAndGet(nodeResponse.getSuccessfulShards()); + + for (BroadcastShardOperationFailedException exception : nodeResponse.getExceptions()) { + if (TransportActions.isShardNotAvailableException(exception)) { + assert node.getVersion().before(Version.V_8_7_0) : node; // we stopped sending these ignored exceptions + } else { + synchronized (this) { + exceptions.add( + new DefaultShardOperationFailedException( + exception.getShardId().getIndexName(), + exception.getShardId().getId(), + exception + ) + ); + } + } + } + } + + @Override + protected void onItemFailure(Map.Entry> entry, Exception e) { + final var node = nodes.get(entry.getKey()); + final var shards = entry.getValue(); + logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e); + + final var failedNodeException = new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e); + synchronized (this) { + for (ShardRouting shard : shards) { + exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), failedNodeException)); + } + } + + totalShards.addAndGet(shards.size()); + } + + @Override + protected void onCompletion(ActionListener listener) { + // ref releases all happen-before here so no need to be synchronized + listener.onResponse( + responseFactory.newResponse(totalShards.get(), successfulShards.get(), exceptions.size(), shardResponses, exceptions) + ); + } + + 
@Override + public String toString() { + return actionName; + } + }.run(task, shardsByNodeId.entrySet().iterator(), listener); } + // not an inline method reference to avoid capturing CancellableFanOut.this. + private final Writeable.Reader nodeResponseReader = NodeResponse::new; + class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler { @Override public void messageReceived(final NodeRequest request, TransportChannel channel, Task task) throws Exception { @@ -415,87 +385,52 @@ private void executeAsDataNode( ) { logger.trace("[{}] executing operation on [{}] shards", actionName, shards.size()); - final var results = new ArrayList(shards.size()); - final var exceptions = new ArrayList(0); + new CancellableFanOut() { - final var resultListener = new ListenableFuture(); - final var resultListenerCompleter = new RunOnce(() -> { - if (task instanceof CancellableTask cancellableTask) { - if (cancellableTask.notifyIfCancelled(resultListener)) { - return; - } + final ArrayList results = new ArrayList<>(shards.size()); + final ArrayList exceptions = new ArrayList<>(0); + + @Override + protected void sendItemRequest(ShardRouting shardRouting, ActionListener listener) { + logger.trace(() -> format("[%s] executing operation for shard [%s]", actionName, shardRouting.shortSummary())); + ActionRunnable.wrap(listener, l -> shardOperation(request, shardRouting, task, l)).run(); } - // ref releases all happen-before here so no need to be synchronized - resultListener.onResponse(new NodeResponse(nodeId, shards.size(), results, exceptions)); - }); - - final var shardFailureListeners = new ListenableFuture(); - if (task instanceof CancellableTask cancellableTask) { - cancellableTask.addListener(() -> { - assert cancellableTask.isCancelled(); - resultListenerCompleter.run(); - cancellableTask.notifyIfCancelled(shardFailureListeners); - }); - } - try (var refs = new RefCountingRunnable(() -> { - resultListener.addListener(listener); - 
resultListenerCompleter.run(); - })) { - for (final var shardRouting : shards) { - if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { - return; + @Override + protected void onItemResponse(ShardRouting shardRouting, ShardOperationResult shardOperationResult) { + synchronized (results) { + results.add(shardOperationResult); } + } - final ActionListener shardListener = ActionListener.notifyOnce(new ActionListener<>() { - @Override - public void onResponse(ShardOperationResult shardOperationResult) { - logger.trace(() -> format("[%s] completed operation for shard [%s]", actionName, shardRouting.shortSummary())); - synchronized (results) { - results.add(shardOperationResult); - } - } - - @Override - public void onFailure(Exception e) { - if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { - return; - } - logger.log( - TransportActions.isShardNotAvailableException(e) ? Level.TRACE : Level.DEBUG, - () -> format("[%s] failed to execute operation for shard [%s]", actionName, shardRouting.shortSummary()), - e + @Override + protected void onItemFailure(ShardRouting shardRouting, Exception e) { + logger.log( + TransportActions.isShardNotAvailableException(e) ? 
Level.TRACE : Level.DEBUG, + () -> format("[%s] failed to execute operation for shard [%s]", actionName, shardRouting.shortSummary()), + e + ); + if (TransportActions.isShardNotAvailableException(e) == false) { + synchronized (exceptions) { + exceptions.add( + new BroadcastShardOperationFailedException(shardRouting.shardId(), "operation " + actionName + " failed", e) ); - if (TransportActions.isShardNotAvailableException(e) == false) { - synchronized (exceptions) { - exceptions.add( - new BroadcastShardOperationFailedException( - shardRouting.shardId(), - "operation " + actionName + " failed", - e - ) - ); - } - } } + } + } - @Override - public String toString() { - return "[" + actionName + "][" + shardRouting + "]"; - } - }); + @Override + protected void onCompletion(ActionListener listener) { + // ref releases all happen-before here so no need to be synchronized + listener.onResponse(new NodeResponse(nodeId, shards.size(), results, exceptions)); - if (task instanceof CancellableTask) { - shardFailureListeners.addListener(shardListener); - } + } - logger.trace(() -> format("[%s] executing operation for shard [%s]", actionName, shardRouting.shortSummary())); - ActionRunnable.wrap( - ActionListener.releaseAfter(shardListener, refs.acquire()), - l -> shardOperation(request, shardRouting, task, l) - ).run(); + @Override + public String toString() { + return actionName; } - } + }.run(task, shards.iterator(), listener); } class NodeRequest extends TransportRequest implements IndicesRequest { diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java index 5f805efe0c176..7568ad2268bd0 100644 --- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java @@ -15,16 +15,14 @@ import org.elasticsearch.action.ActionRunnable; 
import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.CancellableFanOut; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.ListenableFuture; -import org.elasticsearch.common.util.concurrent.RunOnce; -import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportChannel; @@ -131,84 +129,61 @@ protected void doExecute(Task task, NodesRequest request, ActionListener(request.concreteNodes().length); - final var exceptions = new ArrayList(0); + new CancellableFanOut() { - final var resultListener = new ListenableFuture(); - final var resultListenerCompleter = new RunOnce(() -> { - if (task instanceof CancellableTask cancellableTask) { - if (cancellableTask.notifyIfCancelled(resultListener)) { - return; - } - } - // ref releases all happen-before here so no need to be synchronized - threadPool.executor(finalExecutor) - .execute(ActionRunnable.wrap(resultListener, l -> newResponseAsync(task, request, responses, exceptions, l))); - }); - - final var nodeCancellationListener = new ListenableFuture(); // collects node listeners & completes them if cancelled - if (task instanceof CancellableTask cancellableTask) { - cancellableTask.addListener(() -> { - assert cancellableTask.isCancelled(); - resultListenerCompleter.run(); - cancellableTask.notifyIfCancelled(nodeCancellationListener); - }); - } - - final var transportRequestOptions = 
TransportRequestOptions.timeout(request.timeout()); - - try (var refs = new RefCountingRunnable(() -> { - resultListener.addListener(listener); - resultListenerCompleter.run(); - })) { - for (final var node : request.concreteNodes()) { - final ActionListener nodeResponseListener = ActionListener.notifyOnce(new ActionListener<>() { - @Override - public void onResponse(NodeResponse nodeResponse) { - synchronized (responses) { - responses.add(nodeResponse); - } - } + final ArrayList responses = new ArrayList<>(request.concreteNodes().length); + final ArrayList exceptions = new ArrayList<>(0); - @Override - public void onFailure(Exception e) { - if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { - return; - } - - logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e); - synchronized (exceptions) { - exceptions.add(new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e)); - } - } - - @Override - public String toString() { - return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]"; - } - }); - - if (task instanceof CancellableTask) { - nodeCancellationListener.addListener(nodeResponseListener); - } + final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); + @Override + protected void sendItemRequest(DiscoveryNode discoveryNode, ActionListener listener) { final var nodeRequest = newNodeRequest(request); if (task != null) { nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); } transportService.sendRequest( - node, + discoveryNode, transportNodeAction, nodeRequest, transportRequestOptions, - new ActionListenerResponseHandler<>( - ActionListener.releaseAfter(nodeResponseListener, refs.acquire()), - in -> newNodeResponse(in, node) - ) + new ActionListenerResponseHandler<>(listener, nodeResponseReader(discoveryNode)) ); } - } + + @Override + protected void onItemResponse(DiscoveryNode 
discoveryNode, NodeResponse nodeResponse) { + synchronized (responses) { + responses.add(nodeResponse); + } + } + + @Override + protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) { + logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, discoveryNode), e); + synchronized (exceptions) { + exceptions.add(new FailedNodeException(discoveryNode.getId(), "Failed node [" + discoveryNode.getId() + "]", e)); + } + } + + @Override + protected void onCompletion(ActionListener listener) { + // ref releases all happen-before here so no need to be synchronized + threadPool.executor(finalExecutor) + .execute(ActionRunnable.wrap(listener, l -> newResponseAsync(task, request, responses, exceptions, l))); + } + + @Override + public String toString() { + return actionName; + } + }.run(task, Iterators.forArray(request.concreteNodes()), listener); + } + + private Writeable.Reader nodeResponseReader(DiscoveryNode discoveryNode) { + // not an inline lambda to avoid capturing CancellableFanOut.this. + return in -> TransportNodesAction.this.newNodeResponse(in, discoveryNode); } /** diff --git a/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java new file mode 100644 index 0000000000000..a817faaf3b6b8 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.action.support; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelHelper; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.ReachabilityChecker; +import org.hamcrest.Matchers; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +public class CancellableFanOutTests extends ESTestCase { + + public void testFanOutWithoutCancellation() { + final var task = randomFrom( + new Task(1, "test", "test", "", TaskId.EMPTY_TASK_ID, Map.of()), + new CancellableTask(1, "test", "test", "", TaskId.EMPTY_TASK_ID, Map.of()), + null + ); + final var future = new PlainActionFuture(); + + final var itemListeners = new HashMap>(); + + new CancellableFanOut() { + int counter; + + @Override + protected void sendItemRequest(String item, ActionListener listener) { + itemListeners.put(item, listener); + } + + @Override + protected void onItemResponse(String item, String itemResponse) { + assertThat(item, Matchers.oneOf("a", "c")); + assertEquals(item + "-response", itemResponse); + counter += 1; + } + + @Override + protected void onItemFailure(String item, Exception e) { + assertEquals("b", item); + counter += 1; + } + + @Override + protected void onCompletion(ActionListener listener) { + assertEquals(3, counter); + listener.onResponse("completed"); + } + }.run(task, List.of("a", "b", "c").iterator(), future); + + itemListeners.remove("a").onResponse("a-response"); + assertFalse(future.isDone()); + itemListeners.remove("b").onFailure(new ElasticsearchException("b-response")); + assertFalse(future.isDone()); + itemListeners.remove("c").onResponse("c-response"); + 
assertTrue(future.isDone()); + assertEquals("completed", future.actionGet()); + } + + public void testReleaseOnCancellation() { + final var task = new CancellableTask(1, "test", "test", "", TaskId.EMPTY_TASK_ID, Map.of()); + final var future = new PlainActionFuture(); + + final var itemListeners = new HashMap>(); + final var handledItemResponse = new AtomicBoolean(); + + final var reachabilityChecker = new ReachabilityChecker(); + reachabilityChecker.register(new CancellableFanOut() { + @Override + protected void sendItemRequest(String item, ActionListener listener) { + itemListeners.put(item, listener); + } + + @Override + protected void onItemResponse(String item, String itemResponse) { + assertEquals("a", item); + assertEquals("a-response", itemResponse); + assertTrue(handledItemResponse.compareAndSet(false, true)); + } + + @Override + protected void onItemFailure(String item, Exception e) { + fail(item); + } + + @Override + protected void onCompletion(ActionListener listener) { + fail("onCompletion"); + } + }).run(task, List.of("a", "b", "c").iterator(), future); + + itemListeners.remove("a").onResponse("a-response"); + assertTrue(handledItemResponse.get()); + reachabilityChecker.checkReachable(); + + TaskCancelHelper.cancel(task, "test"); + reachabilityChecker.ensureUnreachable(); // even though we're still holding on to some item listeners. 
+ assertFalse(future.isDone()); + + itemListeners.remove("b").onResponse("b-response"); + assertFalse(future.isDone()); + + itemListeners.remove("c").onFailure(new ElasticsearchException("c-response")); + assertTrue(itemListeners.isEmpty()); + assertTrue(future.isDone()); + expectThrows(TaskCancelledException.class, future::actionGet); + } +} From 10c19563645f7b01755b0c3e95fb749af9871c38 Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 26 May 2023 12:19:54 +0100 Subject: [PATCH 2/3] =?UTF-8?q?An=20async=20method=20that=20must=20complet?= =?UTF-8?q?e=20its=20listener=20is=20just=20a=20sync=20method=20?= =?UTF-8?q?=F0=9F=A4=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../action/support/CancellableFanOut.java | 10 +++------- .../node/TransportBroadcastByNodeAction.java | 15 +++++++++------ .../support/nodes/TransportNodesAction.java | 14 +++++++++----- .../support/CancellableFanOutTests.java | 19 ++++++++++++++----- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java index dfe4ebe24ae52..c307e1e2cb91d 100644 --- a/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java +++ b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java @@ -56,11 +56,7 @@ public final void run(@Nullable Task task, Iterator itemsIterator, ActionL if (cancellableTask != null && cancellableTask.notifyIfCancelled(resultListener)) { return; } - onCompletion(resultListener); - - // It's important that onCompletion() completes resultListener before returning, because otherwise there's a risk that - // a cancellation arrives later which might unexpectedly complete the final listener on a transport thread. 
- assert resultListener.isDone() : "onCompletion did not complete its listener"; + ActionListener.completeWith(resultListener, this::onCompletion); }); // Collects the per-item listeners up so they can all be completed exceptionally on cancellation. Never completed successfully. @@ -152,10 +148,10 @@ public String toString() { /** * Called when responses for all items have been processed, on the thread that processed the last per-item response. Not called if the - * task is cancelled. Must complete the listener before returning. + * task is cancelled. *
<p>
* Note that it's easy to accidentally capture another reference to this class when implementing this method, and that will prevent the * early release of any accumulated results. Beware of lambdas, and test carefully. */ - protected abstract void onCompletion(ActionListener listener); + protected abstract FinalResponse onCompletion() throws Exception; } diff --git a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java index 25b1060461867..aec75e3300481 100644 --- a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java @@ -346,10 +346,14 @@ protected void onItemFailure(Map.Entry> entry, Except } @Override - protected void onCompletion(ActionListener listener) { + protected Response onCompletion() { // ref releases all happen-before here so no need to be synchronized - listener.onResponse( - responseFactory.newResponse(totalShards.get(), successfulShards.get(), exceptions.size(), shardResponses, exceptions) + return responseFactory.newResponse( + totalShards.get(), + successfulShards.get(), + exceptions.size(), + shardResponses, + exceptions ); } @@ -420,10 +424,9 @@ protected void onItemFailure(ShardRouting shardRouting, Exception e) { } @Override - protected void onCompletion(ActionListener listener) { + protected NodeResponse onCompletion() { // ref releases all happen-before here so no need to be synchronized - listener.onResponse(new NodeResponse(nodeId, shards.size(), results, exceptions)); - + return new NodeResponse(nodeId, shards.size(), results, exceptions); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java 
b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java index 7568ad2268bd0..fedd357501ac1 100644 --- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportChannel; @@ -129,7 +130,7 @@ protected void doExecute(Task task, NodesRequest request, ActionListener() { + new CancellableFanOut, Exception>>() { final ArrayList responses = new ArrayList<>(request.concreteNodes().length); final ArrayList exceptions = new ArrayList<>(0); @@ -168,17 +169,20 @@ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) { } @Override - protected void onCompletion(ActionListener listener) { + protected CheckedConsumer, Exception> onCompletion() { // ref releases all happen-before here so no need to be synchronized - threadPool.executor(finalExecutor) - .execute(ActionRunnable.wrap(listener, l -> newResponseAsync(task, request, responses, exceptions, l))); + return l -> newResponseAsync(task, request, responses, exceptions, l); } @Override public String toString() { return actionName; } - }.run(task, Iterators.forArray(request.concreteNodes()), listener); + }.run( + task, + Iterators.forArray(request.concreteNodes()), + listener.delegateFailure((l, r) -> threadPool.executor(finalExecutor).execute(ActionRunnable.wrap(l, r))) + ); } private Writeable.Reader nodeResponseReader(DiscoveryNode discoveryNode) { diff --git a/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java 
b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java index a817faaf3b6b8..db48b09e95a08 100644 --- a/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java @@ -35,6 +35,7 @@ public void testFanOutWithoutCancellation() { final var future = new PlainActionFuture(); final var itemListeners = new HashMap>(); + final var finalFailure = randomBoolean(); new CancellableFanOut() { int counter; @@ -58,9 +59,13 @@ protected void onItemFailure(String item, Exception e) { } @Override - protected void onCompletion(ActionListener listener) { + protected String onCompletion() { assertEquals(3, counter); - listener.onResponse("completed"); + if (finalFailure) { + throw new ElasticsearchException("failed"); + } else { + return "completed"; + } } }.run(task, List.of("a", "b", "c").iterator(), future); @@ -70,7 +75,11 @@ protected void onCompletion(ActionListener listener) { assertFalse(future.isDone()); itemListeners.remove("c").onResponse("c-response"); assertTrue(future.isDone()); - assertEquals("completed", future.actionGet()); + if (finalFailure) { + assertEquals("failed", expectThrows(ElasticsearchException.class, future::actionGet).getMessage()); + } else { + assertEquals("completed", future.actionGet()); + } } public void testReleaseOnCancellation() { @@ -100,8 +109,8 @@ protected void onItemFailure(String item, Exception e) { } @Override - protected void onCompletion(ActionListener listener) { - fail("onCompletion"); + protected String onCompletion() { + throw new AssertionError("onCompletion"); } }).run(task, List.of("a", "b", "c").iterator(), future); From d3cbd49dbce050f70a233d6078706aee5b5fcb5a Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 26 May 2023 12:21:52 +0100 Subject: [PATCH 3/3] Comment --- .../org/elasticsearch/action/support/CancellableFanOut.java | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java index c307e1e2cb91d..7c6f022eb7e83 100644 --- a/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java +++ b/server/src/main/java/org/elasticsearch/action/support/CancellableFanOut.java @@ -56,6 +56,8 @@ public final void run(@Nullable Task task, Iterator itemsIterator, ActionL if (cancellableTask != null && cancellableTask.notifyIfCancelled(resultListener)) { return; } + // It's important that we complete resultListener before returning, because otherwise there's a risk that a cancellation arrives + // later which might unexpectedly complete the final listener on a transport thread. ActionListener.completeWith(resultListener, this::onCompletion); });