diff --git a/server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java b/server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java deleted file mode 100644 index aafd6166cb364..0000000000000 --- a/server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.support; - -import java.util.Collection; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReferenceArray; - -/** - * This class tracks the intermediate responses that will be used to create aggregated cluster response to a request. It also gives the - * possibility to discard the intermediate results when asked, for example when the initial request is cancelled, in order to release the - * resources. - */ -public class NodeResponseTracker { - - private final AtomicInteger counter = new AtomicInteger(); - private final int expectedResponsesCount; - private volatile AtomicReferenceArray responses; - private volatile Exception causeOfDiscarding; - - public NodeResponseTracker(int size) { - this.expectedResponsesCount = size; - this.responses = new AtomicReferenceArray<>(size); - } - - public NodeResponseTracker(Collection array) { - this.expectedResponsesCount = array.size(); - this.responses = new AtomicReferenceArray<>(array.toArray()); - } - - /** - * This method discards the results collected so far to free up the resources. - * @param cause the discarding, this will be communicated if they try to access the discarded results - */ - public void discardIntermediateResponses(Exception cause) { - if (responses != null) { - this.causeOfDiscarding = cause; - responses = null; - } - } - - public boolean responsesDiscarded() { - return responses == null; - } - - /** - * This method stores a new node response if the intermediate responses haven't been discarded yet. If the responses are not discarded - * the method asserts that this is the first response encountered from this node to protect from miscounting the responses in case of a - * double invocation. If the responses have been discarded we accept this risk for simplicity. - * @param nodeIndex, the index that represents a single node of the cluster - * @param response, a response can be either a NodeResponse or an error - * @return true if all the nodes' responses have been received, else false - */ - public boolean trackResponseAndCheckIfLast(int nodeIndex, Object response) { - AtomicReferenceArray responses = this.responses; - - if (responsesDiscarded() == false) { - boolean firstEncounter = responses.compareAndSet(nodeIndex, null, response); - assert firstEncounter : "a response should be tracked only once"; - } - return counter.incrementAndGet() == getExpectedResponseCount(); - } - - /** - * Returns the tracked response or null if the response hasn't been received yet for a specific index that represents a node of the - * cluster. - * @throws DiscardedResponsesException if the responses have been discarded - */ - public Object getResponse(int nodeIndex) throws DiscardedResponsesException { - AtomicReferenceArray responses = this.responses; - if (responsesDiscarded()) { - throw new DiscardedResponsesException(causeOfDiscarding); - } - return responses.get(nodeIndex); - } - - public int getExpectedResponseCount() { - return expectedResponsesCount; - } - - /** - * This exception is thrown when the {@link NodeResponseTracker} is asked to give information about the responses after they have been - * discarded. - */ - public static class DiscardedResponsesException extends Exception { - - public DiscardedResponsesException(Exception cause) { - super(cause); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java index 82cc91e620d7e..f057d3e671a4b 100644 --- a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java @@ -8,16 +8,21 @@ package org.elasticsearch.action.support.broadcast.node; +import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.DefaultShardOperationFailedException; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.action.support.NodeResponseTracker; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.broadcast.BaseBroadcastResponse; import org.elasticsearch.action.support.broadcast.BroadcastRequest; @@ -25,7 +30,6 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardsIterator; @@ -33,18 +37,16 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.ListenableFuture; +import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.transport.TransportChannel; -import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; -import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import java.io.IOException; @@ -53,7 +55,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; import static org.elasticsearch.core.Strings.format; @@ -121,47 +122,6 @@ public TransportBroadcastByNodeAction( ); } - private Response newResponse( - NodeResponseTracker nodeResponseTracker, - int unavailableShardCount, - Map> nodes, - ResponseFactory responseFactory - ) throws NodeResponseTracker.DiscardedResponsesException { - int totalShards = 0; - int successfulShards = 0; - List broadcastByNodeResponses = new ArrayList<>(); - List exceptions = new ArrayList<>(); - for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); i++) { - Object response = nodeResponseTracker.getResponse(i); - if (response instanceof FailedNodeException exception) { - totalShards += nodes.get(exception.nodeId()).size(); - for (ShardRouting shard : nodes.get(exception.nodeId())) { - exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), exception)); - } - } else { - @SuppressWarnings("unchecked") - NodeResponse nodeResponse = (NodeResponse) response; - broadcastByNodeResponses.addAll(nodeResponse.results); - totalShards += nodeResponse.getTotalShards(); - successfulShards += nodeResponse.getSuccessfulShards(); - for (BroadcastShardOperationFailedException throwable : nodeResponse.getExceptions()) { - if (TransportActions.isShardNotAvailableException(throwable) == false) { - exceptions.add( - new DefaultShardOperationFailedException( - throwable.getShardId().getIndexName(), - throwable.getShardId().getId(), - throwable - ) - ); - } - } - } - } - totalShards += unavailableShardCount; - int failedShards = exceptions.size(); - return responseFactory.newResponse(totalShards, successfulShards, failedShards, broadcastByNodeResponses, exceptions); - } - /** * Deserialize a shard-level result from an input stream * @@ -263,344 +223,308 @@ protected String[] resolveConcreteIndexNames(ClusterState clusterState, Request @Override protected void doExecute(Task task, Request request, ActionListener listener) { final var clusterState = clusterService.state(); - final var responseFactory = getResponseFactory(request, clusterState); - new AsyncAction(task, request, clusterState, responseFactory, listener).start(); - } - - protected class AsyncAction implements CancellableTask.CancellationListener { - private final Task task; - private final Request request; - private final ActionListener listener; - private final DiscoveryNodes nodes; - private final Map> nodeIds; - private final int unavailableShardCount; - private final NodeResponseTracker nodeResponseTracker; - private final ResponseFactory responseFactory; - - protected AsyncAction( - Task task, - Request request, - ClusterState clusterState, - ResponseFactory responseFactory, - ActionListener listener - ) { - this.task = task; - this.request = request; - this.listener = listener; - this.responseFactory = responseFactory; - - nodes = clusterState.nodes(); - - ClusterBlockException globalBlockException = checkGlobalBlock(clusterState, request); - if (globalBlockException != null) { - throw globalBlockException; - } - String[] concreteIndices = resolveConcreteIndexNames(clusterState, request); - ClusterBlockException requestBlockException = checkRequestBlock(clusterState, request, concreteIndices); - if (requestBlockException != null) { - throw requestBlockException; - } + final var globalBlockException = checkGlobalBlock(clusterState, request); + if (globalBlockException != null) { + throw globalBlockException; + } - if (logger.isTraceEnabled()) { - logger.trace("resolving shards for [{}] based on cluster state version [{}]", actionName, clusterState.version()); - } - ShardsIterator shardIt = shards(clusterState, request, concreteIndices); - nodeIds = new HashMap<>(); - - int unavailableShardCount = 0; - for (ShardRouting shard : shardIt) { - // send a request to the shard only if it is assigned to a node that is in the local node's cluster state - // a scenario in which a shard can be assigned but to a node that is not in the local node's cluster state - // is when the shard is assigned to the master node, the local node has detected the master as failed - // and a new master has not yet been elected; in this situation the local node will have removed the - // master node from the local cluster state, but the shards assigned to the master will still be in the - // routing table as such - if (shard.assignedToNode() && nodes.get(shard.currentNodeId()) != null) { - String nodeId = shard.currentNodeId(); - if (nodeIds.containsKey(nodeId) == false) { - nodeIds.put(nodeId, new ArrayList<>()); - } - nodeIds.get(nodeId).add(shard); - } else { - unavailableShardCount++; - } + final var concreteIndices = resolveConcreteIndexNames(clusterState, request); + final var requestBlockException = checkRequestBlock(clusterState, request, concreteIndices); + if (requestBlockException != null) { + throw requestBlockException; + } + logger.trace(() -> format("resolving shards for [%s] based on cluster state version [%s]", actionName, clusterState.version())); + final ShardsIterator shardIt = shards(clusterState, request, concreteIndices); + final Map> shardsByNodeId = new HashMap<>(); + + final var nodes = clusterState.nodes(); + int unavailableShardCount = 0; + int availableShardCount = 0; + for (final var shard : shardIt) { + // send a request to the shard only if it is assigned to a node that is in the local node's cluster state + // a scenario in which a shard can be assigned but to a node that is not in the local node's cluster state + // is when the shard is assigned to the master node, the local node has detected the master as failed + // and a new master has not yet been elected; in this situation the local node will have removed the + // master node from the local cluster state, but the shards assigned to the master will still be in the + // routing table as such + final var nodeId = shard.currentNodeId(); + if (nodeId != null && nodes.get(nodeId) != null) { + shardsByNodeId.computeIfAbsent(nodeId, n -> new ArrayList<>()).add(shard); + availableShardCount += 1; + } else { + unavailableShardCount++; } - this.unavailableShardCount = unavailableShardCount; - nodeResponseTracker = new NodeResponseTracker(nodeIds.size()); } - public void start() { + executeAsCoordinatingNode( + task, + request, + shardsByNodeId, + unavailableShardCount, + availableShardCount, + nodes, + getResponseFactory(request, clusterState), + listener + ); + } + + private void executeAsCoordinatingNode( + Task task, + Request request, + Map> shardsByNodeId, + int unavailableShardCount, + int availableShardCount, + DiscoveryNodes nodes, + ResponseFactory responseFactory, + ActionListener listener + ) { + final var mutex = new Object(); + final var shardResponses = new ArrayList(availableShardCount); + final var exceptions = new ArrayList(0); + final var totalShards = new AtomicInteger(unavailableShardCount); + final var successfulShards = new AtomicInteger(0); + + final var resultListener = new ListenableFuture(); + final var resultListenerCompleter = new RunOnce(() -> { if (task instanceof CancellableTask cancellableTask) { - cancellableTask.addListener(this); - } - if (nodeIds.size() == 0) { - ActionListener.run(listener, ignored -> onCompletion()); - } else { - int nodeIndex = -1; - for (Map.Entry> entry : nodeIds.entrySet()) { - nodeIndex++; - DiscoveryNode node = nodes.get(entry.getKey()); - sendNodeRequest(node, entry.getValue(), nodeIndex); + if (cancellableTask.notifyIfCancelled(resultListener)) { + return; } } + // ref releases all happen-before here so no need to be synchronized + resultListener.onResponse( + responseFactory.newResponse(totalShards.get(), successfulShards.get(), exceptions.size(), shardResponses, exceptions) + ); + }); + + final var nodeFailureListeners = new ListenableFuture(); + if (task instanceof CancellableTask cancellableTask) { + cancellableTask.addListener(() -> { + assert cancellableTask.isCancelled(); + resultListenerCompleter.run(); + cancellableTask.notifyIfCancelled(nodeFailureListeners); + }); } - private void sendNodeRequest(final DiscoveryNode node, List shards, final int nodeIndex) { - try { - final NodeRequest nodeRequest = new NodeRequest(request, shards, node.getId()); - if (task != null) { - nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } + final var transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); - final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout()); + try (var refs = new RefCountingRunnable(() -> { + resultListener.addListener(listener); + resultListenerCompleter.run(); + })) { + for (final var entry : shardsByNodeId.entrySet()) { + final var node = nodes.get(entry.getKey()); + final var shards = entry.getValue(); - transportService.sendRequest( - node, - transportNodeBroadcastAction, - nodeRequest, - transportRequestOptions, - new TransportResponseHandler() { - @Override - public NodeResponse read(StreamInput in) throws IOException { - return new NodeResponse(in); + final ActionListener nodeResponseListener = ActionListener.notifyOnce(new ActionListener() { + @Override + public void onResponse(NodeResponse nodeResponse) { + synchronized (mutex) { + shardResponses.addAll(nodeResponse.getResults()); } - - @Override - public void handleResponse(NodeResponse response) { - onNodeResponse(node, nodeIndex, response); + totalShards.addAndGet(nodeResponse.getTotalShards()); + successfulShards.addAndGet(nodeResponse.getSuccessfulShards()); + + for (BroadcastShardOperationFailedException exception : nodeResponse.getExceptions()) { + if (TransportActions.isShardNotAvailableException(exception)) { + assert node.getVersion().before(Version.V_8_7_0) : node; // we stopped sending these ignored exceptions + } else { + synchronized (mutex) { + exceptions.add( + new DefaultShardOperationFailedException( + exception.getShardId().getIndexName(), + exception.getShardId().getId(), + exception + ) + ); + } + } } + } - @Override - public void handleException(TransportException exp) { - onNodeFailure(node, nodeIndex, exp); + @Override + public void onFailure(Exception e) { + if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { + return; } - } - ); - } catch (Exception e) { - onNodeFailure(node, nodeIndex, e); - } - } - protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse response) { - if (logger.isTraceEnabled()) { - logger.trace("received response for [{}] from node [{}]", actionName, node.getId()); - } + logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, node), e); - if (nodeResponseTracker.trackResponseAndCheckIfLast(nodeIndex, response)) { - onCompletion(); - } - } + final var failedNodeException = new FailedNodeException(node.getId(), "Failed node [" + node.getId() + "]", e); + synchronized (mutex) { + for (ShardRouting shard : shards) { + exceptions.add( + new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), failedNodeException) + ); + } + } - protected void onNodeFailure(DiscoveryNode node, int nodeIndex, Throwable t) { - String nodeId = node.getId(); - logger.debug(() -> format("failed to execute [%s] on node [%s]", actionName, nodeId), t); - if (nodeResponseTracker.trackResponseAndCheckIfLast( - nodeIndex, - new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t) - )) { - onCompletion(); - } - } + totalShards.addAndGet(shards.size()); + } - protected void onCompletion() { - if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) { - return; - } + @Override + public String toString() { + return "[" + actionName + "][" + node.descriptionWithoutAttributes() + "]"; + } + }); - Response response = null; - try { - response = newResponse(nodeResponseTracker, unavailableShardCount, nodeIds, responseFactory); - } catch (NodeResponseTracker.DiscardedResponsesException e) { - // We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take - // follow-up actions - listener.onFailure((Exception) e.getCause()); - } catch (Exception e) { - logger.debug("failed to combine responses from nodes", e); - listener.onFailure(e); - } - if (response != null) { - try { - listener.onResponse(response); - } catch (Exception e) { - listener.onFailure(e); + if (task instanceof CancellableTask) { + nodeFailureListeners.addListener(nodeResponseListener); } - } - } - @Override - public void onCancelled() { - assert task instanceof CancellableTask : "task must be cancellable"; - try { - ((CancellableTask) task).ensureNotCancelled(); - } catch (TaskCancelledException e) { - nodeResponseTracker.discardIntermediateResponses(e); - } - } + final var nodeRequest = new NodeRequest(request, shards, node.getId()); + if (task != null) { + nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + } - // For testing purposes - public NodeResponseTracker getNodeResponseTracker() { - return nodeResponseTracker; + transportService.sendRequest( + node, + transportNodeBroadcastAction, + nodeRequest, + transportRequestOptions, + new ActionListenerResponseHandler<>( + ActionListener.releaseAfter(nodeResponseListener, refs.acquire()), + NodeResponse::new + ) + ); + } } } class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler { @Override public void messageReceived(final NodeRequest request, TransportChannel channel, Task task) throws Exception { - List shards = request.getShards(); - final int totalShards = shards.size(); - if (logger.isTraceEnabled()) { - logger.trace("[{}] executing operation on [{}] shards", actionName, totalShards); + executeAsDataNode( + task, + request.getIndicesLevelRequest(), + request.getShards(), + request.getNodeId(), + new ChannelActionListener<>(channel, transportNodeBroadcastAction, request) + ); + } + } + + private void executeAsDataNode( + Task task, + Request request, + List shards, + String nodeId, + ActionListener listener + ) { + logger.trace("[{}] executing operation on [{}] shards", actionName, shards.size()); + + final var results = new ArrayList(shards.size()); + final var exceptions = new ArrayList(0); + + final var resultListener = new ListenableFuture(); + final var resultListenerCompleter = new RunOnce(() -> { + if (task instanceof CancellableTask cancellableTask) { + if (cancellableTask.notifyIfCancelled(resultListener)) { + return; + } } - final AtomicArray shardResultOrExceptions = new AtomicArray<>(totalShards); + // ref releases all happen-before here so no need to be synchronized + resultListener.onResponse(new NodeResponse(nodeId, shards.size(), results, exceptions)); + }); + + final var shardFailureListeners = new ListenableFuture(); + if (task instanceof CancellableTask cancellableTask) { + cancellableTask.addListener(() -> { + assert cancellableTask.isCancelled(); + resultListenerCompleter.run(); + cancellableTask.notifyIfCancelled(shardFailureListeners); + }); + } - final AtomicInteger counter = new AtomicInteger(shards.size()); - int shardIndex = -1; - for (final ShardRouting shardRouting : shards) { - shardIndex++; - final int finalShardIndex = shardIndex; - onShardOperation(request, shardRouting, task, ActionListener.notifyOnce(new ActionListener() { + try (var refs = new RefCountingRunnable(() -> { + resultListener.addListener(listener); + resultListenerCompleter.run(); + })) { + for (final var shardRouting : shards) { + if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { + return; + } + final ActionListener shardListener = ActionListener.notifyOnce(new ActionListener<>() { @Override public void onResponse(ShardOperationResult shardOperationResult) { - shardResultOrExceptions.setOnce(finalShardIndex, shardOperationResult); - if (counter.decrementAndGet() == 0) { - finishHim(request, channel, task, shardResultOrExceptions); + logger.trace(() -> format("[%s] completed operation for shard [%s]", actionName, shardRouting.shortSummary())); + synchronized (results) { + results.add(shardOperationResult); } } @Override public void onFailure(Exception e) { - shardResultOrExceptions.setOnce(finalShardIndex, e); - if (counter.decrementAndGet() == 0) { - finishHim(request, channel, task, shardResultOrExceptions); + if (task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) { + return; } - } - })); - } - } - - @SuppressWarnings("unchecked") - private void finishHim(NodeRequest request, TransportChannel channel, Task task, AtomicArray shardResultOrExceptions) { - if (task instanceof CancellableTask) { - try { - ((CancellableTask) task).ensureNotCancelled(); - } catch (TaskCancelledException e) { - try { - channel.sendResponse(e); - } catch (IOException ioException) { - e.addSuppressed(ioException); - logger.warn("failed to send response", e); - } - return; - } - } - List accumulatedExceptions = new ArrayList<>(); - List results = new ArrayList<>(); - for (int i = 0; i < shardResultOrExceptions.length(); i++) { - if (shardResultOrExceptions.get(i) instanceof BroadcastShardOperationFailedException) { - accumulatedExceptions.add((BroadcastShardOperationFailedException) shardResultOrExceptions.get(i)); - } else { - results.add((ShardOperationResult) shardResultOrExceptions.get(i)); - } - } - - try { - channel.sendResponse( - new NodeResponse(request.getNodeId(), shardResultOrExceptions.length(), results, accumulatedExceptions) - ); - } catch (IOException e) { - logger.warn("failed to send response", e); - } - } - - private void onShardOperation( - final NodeRequest request, - final ShardRouting shardRouting, - final Task task, - final ActionListener listener - ) { - if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) { - return; - } - if (logger.isTraceEnabled()) { - logger.trace("[{}] executing operation for shard [{}]", actionName, shardRouting.shortSummary()); - } - final Consumer failureHandler = e -> { - BroadcastShardOperationFailedException failure = new BroadcastShardOperationFailedException( - shardRouting.shardId(), - "operation " + actionName + " failed", - e - ); - failure.setShard(shardRouting.shardId()); - if (TransportActions.isShardNotAvailableException(e)) { - if (logger.isTraceEnabled()) { - logger.trace( + logger.log( + TransportActions.isShardNotAvailableException(e) ? Level.TRACE : Level.DEBUG, () -> format("[%s] failed to execute operation for shard [%s]", actionName, shardRouting.shortSummary()), e ); - } - } else { - if (logger.isDebugEnabled()) { - logger.debug( - () -> format("[%s] failed to execute operation for shard [%s]", actionName, shardRouting.shortSummary()), - e - ); - } - } - listener.onFailure(failure); - }; - try { - shardOperation(request.getIndicesLevelRequest(), shardRouting, task, new ActionListener<>() { - @Override - public void onResponse(ShardOperationResult shardOperationResult) { - if (logger.isTraceEnabled()) { - logger.trace("[{}] completed operation for shard [{}]", actionName, shardRouting.shortSummary()); + if (TransportActions.isShardNotAvailableException(e) == false) { + synchronized (exceptions) { + exceptions.add( + new BroadcastShardOperationFailedException( + shardRouting.shardId(), + "operation " + actionName + " failed", + e + ) + ); + } } - listener.onResponse(shardOperationResult); } @Override - public void onFailure(Exception e) { - failureHandler.accept(e); + public String toString() { + return "[" + actionName + "][" + shardRouting + "]"; } }); - } catch (Exception e) { - assert false : "shardOperation should not throw an exception, but delegate to listener instead"; - failureHandler.accept(e); + + if (task instanceof CancellableTask) { + shardFailureListeners.addListener(shardListener); + } + + logger.trace(() -> format("[%s] executing operation for shard [%s]", actionName, shardRouting.shortSummary())); + ActionRunnable.wrap( + ActionListener.releaseAfter(shardListener, refs.acquire()), + l -> shardOperation(request, shardRouting, task, l) + ).run(); } } } - public class NodeRequest extends TransportRequest implements IndicesRequest { - + class NodeRequest extends TransportRequest implements IndicesRequest { private final Request indicesLevelRequest; private final List shards; private final String nodeId; - public NodeRequest(StreamInput in) throws IOException { + NodeRequest(StreamInput in) throws IOException { super(in); indicesLevelRequest = readRequestFrom(in); shards = in.readList(ShardRouting::new); nodeId = in.readString(); } - public NodeRequest(Request indicesLevelRequest, List shards, String nodeId) { + NodeRequest(Request indicesLevelRequest, List shards, String nodeId) { this.indicesLevelRequest = indicesLevelRequest; this.shards = shards; this.nodeId = nodeId; } - public List getShards() { + List getShards() { return shards; } - public String getNodeId() { + String getNodeId() { return nodeId; } - public Request getIndicesLevelRequest() { + Request getIndicesLevelRequest() { return indicesLevelRequest; } @@ -658,19 +582,23 @@ class NodeResponse extends TransportResponse { this.exceptions = exceptions; } - public String getNodeId() { + String getNodeId() { return nodeId; } - public int getTotalShards() { + int getTotalShards() { return totalShards; } - public int getSuccessfulShards() { + int getSuccessfulShards() { return results.size(); } - public List getExceptions() { + List getResults() { + return results; + } + + List getExceptions() { return exceptions; } diff --git a/server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java b/server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java deleted file mode 100644 index 11d2ee1f12a04..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.support; - -import org.elasticsearch.test.ESTestCase; - -public class NodeResponseTrackerTests extends ESTestCase { - - public void testAllResponsesReceived() throws Exception { - int nodes = randomIntBetween(1, 10); - NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(nodes); - for (int i = 0; i < nodes; i++) { - boolean isLast = i == nodes - 1; - assertEquals( - isLast, - intermediateNodeResponses.trackResponseAndCheckIfLast(i, randomBoolean() ? i : new Exception("from node " + i)) - ); - } - - assertFalse(intermediateNodeResponses.responsesDiscarded()); - assertEquals(nodes, intermediateNodeResponses.getExpectedResponseCount()); - for (int i = 0; i < nodes; i++) { - assertNotNull(intermediateNodeResponses.getResponse(i)); - if (intermediateNodeResponses.getResponse(i)instanceof Integer nodeResponse) { - assertEquals(i, nodeResponse.intValue()); - } - } - } - - public void testDiscardingResults() { - int nodes = randomIntBetween(1, 10); - int cancelAt = randomIntBetween(0, Math.max(0, nodes - 2)); - NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(nodes); - for (int i = 0; i < nodes; i++) { - if (i == cancelAt) { - intermediateNodeResponses.discardIntermediateResponses(new Exception("simulated")); - } - boolean isLast = i == nodes - 1; - assertEquals( - isLast, - intermediateNodeResponses.trackResponseAndCheckIfLast(i, randomBoolean() ? i : new Exception("from node " + i)) - ); - } - - assertTrue(intermediateNodeResponses.responsesDiscarded()); - assertEquals(nodes, intermediateNodeResponses.getExpectedResponseCount()); - expectThrows(NodeResponseTracker.DiscardedResponsesException.class, () -> intermediateNodeResponses.getResponse(0)); - } - - public void testResponseIsRegisteredOnlyOnce() { - NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(1); - assertTrue(intermediateNodeResponses.trackResponseAndCheckIfLast(0, "response1")); - expectThrows(AssertionError.class, () -> intermediateNodeResponses.trackResponseAndCheckIfLast(0, "response2")); - } -} diff --git a/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java b/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java index df4a26260f0a2..7e76190e60a00 100644 --- a/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java @@ -633,4 +633,55 @@ action.new NodeRequest( ); } + public void testShardResultsReleasedOnCancellation() throws Exception { + final var listeners = new ArrayList>(); + + action = new TestTransportBroadcastByNodeAction("indices:admin/shard_level_gc_test") { + @Override + protected void shardOperation(Request request, ShardRouting shardRouting, Task task, ActionListener listener) { + listeners.add(listener); + } + }; + + final PlainActionFuture nodeResponseFuture = new PlainActionFuture<>(); + final CancellableTask task = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()); + + action.new BroadcastByNodeTransportRequestHandler().messageReceived( + action.new NodeRequest( + new Request(), IntStream.range(0, 3) + .mapToObj(shardId -> TestShardRouting.newShardRouting(TEST_INDEX, shardId, "node-id", true, ShardRoutingState.STARTED)) + .toList(), "node-id" + ), + new TestTransportChannel(nodeResponseFuture), + task + ); + + assertEquals(3, listeners.size()); + + final var reachabilityChecker = new ReachabilityChecker(); + listeners.get(0).onResponse(reachabilityChecker.register(new ShardResult())); + reachabilityChecker.checkReachable(); + + TaskCancelHelper.cancel(task, "simulated"); + reachabilityChecker.ensureUnreachable(); + + listeners.get(1).onResponse(reachabilityChecker.register(new ShardResult())); + reachabilityChecker.ensureUnreachable(); + + assertFalse(nodeResponseFuture.isDone()); + + listeners.get(2).onResponse(reachabilityChecker.register(new ShardResult())); + reachabilityChecker.ensureUnreachable(); + + assertTrue(nodeResponseFuture.isDone()); + assertEquals( + "task cancelled [simulated]", + expectThrows( + java.util.concurrent.ExecutionException.class, + org.elasticsearch.tasks.TaskCancelledException.class, + nodeResponseFuture::get + ).getMessage() + ); + } + }