From 1d15a29e35276fc3cf8fd9f1e7a5ec56216b0f4f Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 3 Jan 2023 08:34:16 +0000 Subject: [PATCH 1/5] Introduce RefCountingRunnable Today a `CountDownActionListener` which wraps a bare `Runnable` collects all the exceptions it receives only to drop them when finally completing the delegate action. Moreover callers must declare up-front the number of times the listener will be completed, which means they must put extra effort into computing this number ahead of time and/or supply an overestimate and then make up the difference with additional artificial completions. This commit introduces `RefCountingRunnable` which allows callers to acquire and release references as needed, executing the delegate `Runnable` once all references are released. It also refactors all the relevant call sites to use this new utility. --- .../elasticsearch/action/ActionListener.java | 31 +++ .../support/CountDownActionListener.java | 9 - .../action/support/RefCountingRunnable.java | 120 +++++++++ .../cluster/InternalClusterInfoService.java | 34 ++- .../cluster/NodeConnectionsService.java | 62 ++--- .../allocation/DiskThresholdMonitor.java | 237 +++++++++--------- .../elasticsearch/ingest/IngestService.java | 71 +++--- .../blobstore/BlobStoreRepository.java | 42 ++-- .../snapshots/RestoreService.java | 78 +++--- .../action/ActionListenerTests.java | 83 ++++++ .../support/CountDownActionListenerTests.java | 1 - .../support/RefCountingRunnableTests.java | 234 +++++++++++++++++ .../allocation/DiskThresholdMonitorTests.java | 27 +- .../ShardSnapshotTaskRunnerTests.java | 11 +- .../snapshots/RestoreServiceTests.java | 18 +- .../testkit/RepositoryAnalyzeAction.java | 68 ++--- 16 files changed, 755 insertions(+), 371 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java create mode 100644 server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java diff --git a/server/src/main/java/org/elasticsearch/action/ActionListener.java b/server/src/main/java/org/elasticsearch/action/ActionListener.java index 6a41a8205b783..7c682ae11723f 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/ActionListener.java @@ -13,6 +13,8 @@ import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.core.CheckedRunnable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import java.util.ArrayList; import java.util.List; @@ -252,6 +254,13 @@ public String toString() { } } + /** + * Creates a listener which releases the given resource on completion (whether success or failure) + */ + static ActionListener releasing(Releasable releasable) { + return wrap(runnableFromReleasable(releasable)); + } + /** * Creates a listener that listens for a response (or failure) and executes the * corresponding runnable when the response (or failure) is received. @@ -335,6 +344,14 @@ static ActionListener runAfter(ActionListener del return new RunAfterActionListener<>(delegate, runAfter); } + /** + * Wraps a given listener and returns a new listener which releases the provided {@code releaseAfter} + * resource when the listener is notified via either {@code #onResponse} or {@code #onFailure}. + */ + static ActionListener releaseAfter(ActionListener delegate, Releasable releaseAfter) { + return new RunAfterActionListener<>(delegate, runnableFromReleasable(releaseAfter)); + } + final class RunAfterActionListener extends Delegating { private final Runnable runAfter; @@ -471,4 +488,18 @@ static void completeWith(ActionListener listener, CheckedSu throw ex; } } + + private static Runnable runnableFromReleasable(Releasable releasable) { + return new Runnable() { + @Override + public void run() { + Releasables.closeExpectNoException(releasable); + } + + @Override + public String toString() { + return "release[" + releasable + "]"; + } + }; + } } diff --git a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java index e9da843d34c25..2dac0f4c8cb5f 100644 --- a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java @@ -37,15 +37,6 @@ public CountDownActionListener(int groupSize, ActionListener delegate) { countDown = new AtomicInteger(groupSize); } - /** - * Creates a new listener - * @param groupSize the group size - * @param runnable the runnable - */ - public CountDownActionListener(int groupSize, Runnable runnable) { - this(groupSize, ActionListener.wrap(Objects.requireNonNull(runnable))); - } - private boolean countDown() { final var result = countDown.getAndUpdate(current -> Math.max(0, current - 1)); assert result > 0; diff --git a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java new file mode 100644 index 0000000000000..b78fe212386b4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.support; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; + +import java.util.Objects; + +/** + * A mechanism to trigger an action on the completion of some (dynamic) collection of other actions. Basic usage is as follows: + * + *
+ * try (var refs = new RefCountingRunnable(finalRunnable)) {
+ *     for (var item : collection) {
+ *         runAsyncAction(item, refs.acquire()); // releases the acquired ref on completion
+ *     }
+ * }
+ * 
+ * + * The delegate action is completed when execution leaves the try-with-resources block and every acquired reference is released. Unlike a + * {@link CountDown} there is no need to declare the number of subsidiary actions up front (refs can be acquired dynamically as needed) nor + * does the caller need to check for completion each time a reference is released. Moreover even outside the try-with-resources block you + * can continue to acquire additional listeners, even in a separate thread, as long as there's at least one listener outstanding: + * + *
+ * try (var refs = new RefCountingRunnable(finalRunnable)) {
+ *     for (var item : collection) {
+ *         if (condition(item)) {
+ *             runAsyncAction(item, refs.acquire());
+ *         }
+ *     }
+ *     if (flag) {
+ *         runOneOffAsyncAction(refs.acquire());
+ *         return;
+ *     }
+ *     for (var item : otherCollection) {
+ *         var itemRef = refs.acquire(); // delays completion while the background action is pending
+ *         executorService.execute(() -> {
+ *             try (var ignored = itemRef) {
+ *                 if (condition(item)) {
+ *                     runOtherAsyncAction(item, refs.acquire());
+ *                 }
+ *             }
+ *         });
+ *     }
+ * }
+ * 
+ * + * In particular (and also unlike a {@link CountDown}) this works even if you don't acquire any extra refs at all: in that case, the + * delegate action executes at the end of the try-with-resources block. + */ +public final class RefCountingRunnable implements Releasable { + + private static final Logger logger = LogManager.getLogger(RefCountingRunnable.class); + static final String ALREADY_CLOSED_MESSAGE = "already closed, cannot acquire or release any further refs"; + + private final Runnable delegate; // TODO drop this when #92616 merged + private final RefCounted refCounted; + + /** + * Construct a {@link RefCountingRunnable} which executes {@code delegate} when all refs are released. + * @param delegate The action to execute when all refs are released. This action must not throw any exception. + */ + public RefCountingRunnable(Runnable delegate) { + this.delegate = Objects.requireNonNull(delegate); + this.refCounted = AbstractRefCounted.of(delegate); + } + + /** + * Acquire a reference to this object and return an action which releases it. The delegate {@link Runnable} is called when all its + * references have been released. + */ + public Releasable acquire() { + if (refCounted.tryIncRef()) { + // closing ourselves releases a ref, so we can just return 'this' and avoid any allocation; callers only see a Releasable + return this; + } + assert false : ALREADY_CLOSED_MESSAGE; + throw new IllegalStateException(ALREADY_CLOSED_MESSAGE); + } + + /** + * Acquire a reference to this object and return a listener which releases it when notified. The delegate {@link Runnable} is called + * when all its references have been released. + */ + public ActionListener acquireListener() { + return ActionListener.releasing(acquire()); + } + + /** + * Release a reference to this object, and execute the delegate {@link Runnable} if there are no other references. + */ + @Override + public void close() { + try { + refCounted.decRef(); + } catch (Exception e) { + logger.error("exception in delegate", e); + assert false : e; + } + } + + @Override + public String toString() { + return "refCounted[" + delegate + "]"; // TODO refCounted.toString() when #92616 merged + } + +} diff --git a/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java b/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java index c36ee576923a8..113062ee4d4d2 100644 --- a/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java +++ b/server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.admin.indices.stats.ShardStats; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.block.ClusterBlockException; @@ -32,7 +33,6 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.store.StoreStats; @@ -161,7 +161,7 @@ public void clusterChanged(ClusterChangedEvent event) { private class AsyncRefresh { private final List> thisRefreshListeners; - private final CountDown countDown = new CountDown(2); + private final RefCountingRunnable fetchRefs = new RefCountingRunnable(this::callListeners); AsyncRefresh(List> thisRefreshListeners) { this.thisRefreshListeners = thisRefreshListeners; @@ -177,15 +177,15 @@ void execute() { return; } - assert countDown.isCountedDown() == false; logger.trace("starting async refresh"); - try (var ignored = threadPool.getThreadContext().clearTraceContext()) { - fetchNodeStats(); - } - - try (var ignored = threadPool.getThreadContext().clearTraceContext()) { - fetchIndicesStats(); + try (var ignoredRefs = fetchRefs) { + try (var ignored = threadPool.getThreadContext().clearTraceContext()) { + fetchNodeStats(); + } + try (var ignored = threadPool.getThreadContext().clearTraceContext()) { + fetchIndicesStats(); + } } } @@ -203,7 +203,7 @@ private void fetchIndicesStats() { logger, threadPool, ThreadPool.Names.MANAGEMENT, - ActionListener.runAfter(new ActionListener<>() { + ActionListener.releaseAfter(new ActionListener<>() { @Override public void onResponse(IndicesStatsResponse indicesStatsResponse) { logger.trace("received indices stats response"); @@ -277,7 +277,7 @@ public void onFailure(Exception e) { } indicesStatsSummary = IndicesStatsSummary.EMPTY; } - }, this::onStatsProcessed), + }, fetchRefs.acquire()), false ) ); @@ -288,7 +288,7 @@ private void fetchNodeStats() { nodesStatsRequest.clear(); nodesStatsRequest.addMetric(NodesStatsRequest.Metric.FS.metricName()); nodesStatsRequest.timeout(fetchTimeout); - client.admin().cluster().nodesStats(nodesStatsRequest, ActionListener.runAfter(new ActionListener<>() { + client.admin().cluster().nodesStats(nodesStatsRequest, ActionListener.releaseAfter(new ActionListener<>() { @Override public void onResponse(NodesStatsResponse nodesStatsResponse) { logger.trace("received node stats response"); @@ -318,18 +318,12 @@ public void onFailure(Exception e) { leastAvailableSpaceUsages = Map.of(); mostAvailableSpaceUsages = Map.of(); } - }, this::onStatsProcessed)); - } - - private void onStatsProcessed() { - if (countDown.countDown()) { - logger.trace("stats all received, computing cluster info and notifying listeners"); - callListeners(); - } + }, fetchRefs.acquire())); } private void callListeners() { try { + logger.trace("stats all received, computing cluster info and notifying listeners"); final ClusterInfo clusterInfo = getClusterInfo(); boolean anyListeners = false; for (final Consumer listener : listeners) { diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java index 2e67288358c2b..1288fa10b72c7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java @@ -11,7 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.coordination.FollowersChecker; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -28,7 +28,6 @@ import org.elasticsearch.transport.TransportService; import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -97,28 +96,25 @@ public void connectToNodes(DiscoveryNodes discoveryNodes, Runnable onCompletion) return; } - final CountDownActionListener listener = new CountDownActionListener(discoveryNodes.getSize(), onCompletion); - final List runnables = new ArrayList<>(discoveryNodes.getSize()); - synchronized (mutex) { - for (final DiscoveryNode discoveryNode : discoveryNodes) { - ConnectionTarget connectionTarget = targetsByNode.get(discoveryNode); - final boolean isNewNode = connectionTarget == null; - if (isNewNode) { - connectionTarget = new ConnectionTarget(discoveryNode); - targetsByNode.put(discoveryNode, connectionTarget); - } + try (var refs = new RefCountingRunnable(onCompletion)) { + synchronized (mutex) { + for (final DiscoveryNode discoveryNode : discoveryNodes) { + ConnectionTarget connectionTarget = targetsByNode.get(discoveryNode); + final boolean isNewNode = connectionTarget == null; + if (isNewNode) { + connectionTarget = new ConnectionTarget(discoveryNode); + targetsByNode.put(discoveryNode, connectionTarget); + } - if (isNewNode) { - logger.debug("connecting to {}", discoveryNode); - runnables.add( - connectionTarget.connect(ActionListener.runAfter(listener, () -> logger.debug("connected to {}", discoveryNode))) - ); - } else { - // known node, try and ensure it's connected but do not wait - logger.trace("checking connection to existing node [{}]", discoveryNode); - runnables.add(connectionTarget.connect(null)); - runnables.add(() -> listener.onResponse(null)); + if (isNewNode) { + logger.debug("connecting to {}", discoveryNode); + runnables.add(connectionTarget.connect(refs.acquire())); + } else { + // known node, try and ensure it's connected but do not wait + logger.trace("checking connection to existing node [{}]", discoveryNode); + runnables.add(connectionTarget.connect(null)); + } } } } @@ -150,15 +146,11 @@ public void disconnectFromNodesExcept(DiscoveryNodes discoveryNodes) { */ void ensureConnections(Runnable onCompletion) { final List runnables = new ArrayList<>(); - synchronized (mutex) { - final Collection connectionTargets = targetsByNode.values(); - if (connectionTargets.isEmpty()) { - runnables.add(onCompletion); - } else { + try (var refs = new RefCountingRunnable(onCompletion)) { + synchronized (mutex) { logger.trace("ensureConnections: {}", targetsByNode); - final CountDownActionListener listener = new CountDownActionListener(connectionTargets.size(), onCompletion); - for (final ConnectionTarget connectionTarget : connectionTargets) { - runnables.add(connectionTarget.connect(listener)); + for (ConnectionTarget connectionTarget : targetsByNode.values()) { + runnables.add(connectionTarget.connect(refs.acquire())); } } } @@ -227,7 +219,7 @@ private void setConnectionRef(Releasable connectionReleasable) { Releasables.close(connectionRef.getAndSet(connectionReleasable)); } - Runnable connect(ActionListener listener) { + Runnable connect(Releasable onCompletion) { return () -> { final boolean alreadyConnected = transportService.nodeConnected(discoveryNode); @@ -258,9 +250,7 @@ public void onResponse(Releasable connectionReleasable) { logger.debug("connected to stale {} - releasing stale connection", discoveryNode); setConnectionRef(null); } - if (listener != null) { - listener.onResponse(null); - } + Releasables.closeExpectNoException(onCompletion); } @Override @@ -274,9 +264,7 @@ public void onFailure(Exception e) { e ); setConnectionRef(null); - if (listener != null) { - listener.onFailure(e); - } + Releasables.closeExpectNoException(onCompletion); } }); }; diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java index 82413fe3723ae..86a8f20115c55 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java @@ -11,7 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterInfo; import org.elasticsearch.cluster.ClusterState; @@ -31,6 +32,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.Releasable; import java.util.ArrayList; import java.util.Collections; @@ -301,123 +303,109 @@ public void onNewInfo(ClusterInfo info) { } } - final ActionListener listener = new CountDownActionListener(3, this::checkFinished); - - if (reroute) { - logger.debug("rerouting shards: [{}]", explanation); - rerouteService.reroute("disk threshold monitor", Priority.HIGH, ActionListener.wrap(reroutedClusterState -> { - - for (DiskUsage diskUsage : usagesOverHighThreshold) { - final RoutingNode routingNode = reroutedClusterState.getRoutingNodes().node(diskUsage.getNodeId()); - final DiskUsage usageIncludingRelocations; - final long relocatingShardsSize; - if (routingNode != null) { // might be temporarily null if the ClusterInfoService and the ClusterService are out of step - relocatingShardsSize = sizeOfRelocatingShards(routingNode, diskUsage, info, reroutedClusterState); - usageIncludingRelocations = new DiskUsage( - diskUsage.getNodeId(), - diskUsage.getNodeName(), - diskUsage.getPath(), - diskUsage.getTotalBytes(), - diskUsage.getFreeBytes() - relocatingShardsSize - ); - } else { - usageIncludingRelocations = diskUsage; - relocatingShardsSize = 0L; - } - final ByteSizeValue total = ByteSizeValue.ofBytes(usageIncludingRelocations.getTotalBytes()); - - if (usageIncludingRelocations.getFreeBytes() < diskThresholdSettings.getFreeBytesThresholdHighStage(total).getBytes()) { - nodesOverHighThresholdAndRelocating.remove(diskUsage.getNodeId()); - logger.warn( - "high disk watermark [{}] exceeded on {}, shards will be relocated away from this node; " - + "currently relocating away shards totalling [{}] bytes; the node is expected to continue to exceed " - + "the high disk watermark when these relocations are complete", - diskThresholdSettings.describeHighThreshold(total, false), - diskUsage, - -relocatingShardsSize - ); - } else if (nodesOverHighThresholdAndRelocating.add(diskUsage.getNodeId())) { - logger.info( - "high disk watermark [{}] exceeded on {}, shards will be relocated away from this node; " - + "currently relocating away shards totalling [{}] bytes; the node is expected to be below the high " - + "disk watermark when these relocations are complete", - diskThresholdSettings.describeHighThreshold(total, false), - diskUsage, - -relocatingShardsSize - ); - } else { - logger.debug( - "high disk watermark [{}] exceeded on {}, shards will be relocated away from this node; " - + "currently relocating away shards totalling [{}] bytes", - diskThresholdSettings.describeHighThreshold(total, false), - diskUsage, - -relocatingShardsSize - ); - } - } - - setLastRunTimeMillis(); - listener.onResponse(null); - }, e -> { - logger.debug("reroute failed", e); - setLastRunTimeMillis(); - listener.onFailure(e); - })); - } else { - logger.trace("no reroute required"); - listener.onResponse(null); - } - - // Generate a map of node name to ID so we can use it to look up node replacement targets - final Map nodeNameToId = state.getRoutingNodes() - .stream() - .collect(Collectors.toMap(rn -> rn.node().getName(), RoutingNode::nodeId, (s1, s2) -> s2)); + try (var asyncRefs = new RefCountingRunnable(this::checkFinished)) { + + if (reroute) { + logger.debug("rerouting shards: [{}]", explanation); + rerouteService.reroute( + "disk threshold monitor", + Priority.HIGH, + ActionListener.releaseAfter(ActionListener.runAfter(ActionListener.wrap(reroutedClusterState -> { + + for (DiskUsage diskUsage : usagesOverHighThreshold) { + final RoutingNode routingNode = reroutedClusterState.getRoutingNodes().node(diskUsage.getNodeId()); + final DiskUsage usageIncludingRelocations; + final long relocatingShardsSize; + if (routingNode != null) { // might be temporarily null if ClusterInfoService and ClusterService are out of step + relocatingShardsSize = sizeOfRelocatingShards(routingNode, diskUsage, info, reroutedClusterState); + usageIncludingRelocations = new DiskUsage( + diskUsage.getNodeId(), + diskUsage.getNodeName(), + diskUsage.getPath(), + diskUsage.getTotalBytes(), + diskUsage.getFreeBytes() - relocatingShardsSize + ); + } else { + usageIncludingRelocations = diskUsage; + relocatingShardsSize = 0L; + } + final ByteSizeValue total = ByteSizeValue.ofBytes(usageIncludingRelocations.getTotalBytes()); + + if (usageIncludingRelocations.getFreeBytes() < diskThresholdSettings.getFreeBytesThresholdHighStage(total) + .getBytes()) { + nodesOverHighThresholdAndRelocating.remove(diskUsage.getNodeId()); + logger.warn(""" + high disk watermark [{}] exceeded on {}, shards will be relocated away from this node; currently \ + relocating away shards totalling [{}] bytes; the node is expected to continue to exceed the high disk \ + watermark when these relocations are complete\ + """, diskThresholdSettings.describeHighThreshold(total, false), diskUsage, -relocatingShardsSize); + } else if (nodesOverHighThresholdAndRelocating.add(diskUsage.getNodeId())) { + logger.info(""" + high disk watermark [{}] exceeded on {}, shards will be relocated away from this node; currently \ + relocating away shards totalling [{}] bytes; the node is expected to be below the high disk watermark \ + when these relocations are complete\ + """, diskThresholdSettings.describeHighThreshold(total, false), diskUsage, -relocatingShardsSize); + } else { + logger.debug(""" + high disk watermark [{}] exceeded on {}, shards will be relocated away from this node; currently \ + relocating away shards totalling [{}] bytes\ + """, diskThresholdSettings.describeHighThreshold(total, false), diskUsage, -relocatingShardsSize); + } + } + }, e -> logger.debug("reroute failed", e)), this::setLastRunTimeMillis), asyncRefs.acquire()) + ); + } else { + logger.trace("no reroute required"); + } - // Calculate both the source node id and the target node id of a "replace" type shutdown - final Set nodesIdsPartOfReplacement = state.metadata() - .nodeShutdowns() - .values() - .stream() - .filter(meta -> meta.getType() == SingleNodeShutdownMetadata.Type.REPLACE) - .flatMap(meta -> Stream.of(meta.getNodeId(), nodeNameToId.get(meta.getTargetNodeName()))) - .collect(Collectors.toSet()); - - // Generate a set of all the indices that exist on either the target or source of a node replacement - final Set indicesOnReplaceSourceOrTarget = new HashSet<>(); - for (String nodeId : nodesIdsPartOfReplacement) { - for (ShardRouting shardRouting : state.getRoutingNodes().node(nodeId)) { - indicesOnReplaceSourceOrTarget.add(shardRouting.index().getName()); + // Generate a map of node name to ID so we can use it to look up node replacement targets + final Map nodeNameToId = state.getRoutingNodes() + .stream() + .collect(Collectors.toMap(rn -> rn.node().getName(), RoutingNode::nodeId, (s1, s2) -> s2)); + + // Calculate both the source node id and the target node id of a "replace" type shutdown + final Set nodesIdsPartOfReplacement = state.metadata() + .nodeShutdowns() + .values() + .stream() + .filter(meta -> meta.getType() == SingleNodeShutdownMetadata.Type.REPLACE) + .flatMap(meta -> Stream.of(meta.getNodeId(), nodeNameToId.get(meta.getTargetNodeName()))) + .collect(Collectors.toSet()); + + // Generate a set of all the indices that exist on either the target or source of a node replacement + final Set indicesOnReplaceSourceOrTarget = new HashSet<>(); + for (String nodeId : nodesIdsPartOfReplacement) { + for (ShardRouting shardRouting : state.getRoutingNodes().node(nodeId)) { + indicesOnReplaceSourceOrTarget.add(shardRouting.index().getName()); + } } - } - final Set indicesToAutoRelease = state.routingTable() - .indicesRouting() - .keySet() - .stream() - .filter(index -> indicesNotToAutoRelease.contains(index) == false) - .filter(index -> state.getBlocks().hasIndexBlock(index, IndexMetadata.INDEX_READ_ONLY_ALLOW_DELETE_BLOCK)) - // Do not auto release indices that are on either the source or the target of a node replacement - .filter(index -> indicesOnReplaceSourceOrTarget.contains(index) == false) - .collect(Collectors.toSet()); - - if (indicesToAutoRelease.isEmpty() == false) { - logger.info( - "releasing read-only block on indices " - + indicesToAutoRelease - + " since they are now allocated to nodes with sufficient disk space" - ); - updateIndicesReadOnly(indicesToAutoRelease, listener, false); - } else { - logger.trace("no auto-release required"); - listener.onResponse(null); - } + final Set indicesToAutoRelease = state.routingTable() + .indicesRouting() + .keySet() + .stream() + .filter(index -> indicesNotToAutoRelease.contains(index) == false) + .filter(index -> state.getBlocks().hasIndexBlock(index, IndexMetadata.INDEX_READ_ONLY_ALLOW_DELETE_BLOCK)) + // Do not auto release indices that are on either the source or the target of a node replacement + .filter(index -> indicesOnReplaceSourceOrTarget.contains(index) == false) + .collect(Collectors.toSet()); + + if (indicesToAutoRelease.isEmpty() == false) { + logger.info( + "releasing read-only block on indices " + + indicesToAutoRelease + + " since they are now allocated to nodes with sufficient disk space" + ); + updateIndicesReadOnly(indicesToAutoRelease, asyncRefs.acquire(), false); + } else { + logger.trace("no auto-release required"); + } - indicesToMarkReadOnly.removeIf(index -> state.getBlocks().indexBlocked(ClusterBlockLevel.WRITE, index)); - logger.trace("marking indices as read-only: [{}]", indicesToMarkReadOnly); - if (indicesToMarkReadOnly.isEmpty() == false) { - updateIndicesReadOnly(indicesToMarkReadOnly, listener, true); - } else { - listener.onResponse(null); + indicesToMarkReadOnly.removeIf(index -> state.getBlocks().indexBlocked(ClusterBlockLevel.WRITE, index)); + logger.trace("marking indices as read-only: [{}]", indicesToMarkReadOnly); + if (indicesToMarkReadOnly.isEmpty() == false) { + updateIndicesReadOnly(indicesToMarkReadOnly, asyncRefs.acquire(), true); + } } } @@ -453,23 +441,24 @@ private void setLastRunTimeMillis() { lastRunTimeMillis.getAndUpdate(l -> Math.max(l, currentTimeMillisSupplier.getAsLong())); } - protected void updateIndicesReadOnly(Set indicesToUpdate, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToUpdate, Releasable onCompletion, boolean readOnly) { // set read-only block but don't block on the response - ActionListener wrappedListener = ActionListener.wrap(r -> { - setLastRunTimeMillis(); - listener.onResponse(r); - }, e -> { - logger.debug(() -> "setting indices [" + readOnly + "] read-only failed", e); - setLastRunTimeMillis(); - listener.onFailure(e); - }); Settings readOnlySettings = readOnly ? READ_ONLY_ALLOW_DELETE_SETTINGS : NOT_READ_ONLY_ALLOW_DELETE_SETTINGS; client.admin() .indices() .prepareUpdateSettings(indicesToUpdate.toArray(Strings.EMPTY_ARRAY)) .setSettings(readOnlySettings) .origin("disk-threshold-monitor") - .execute(wrappedListener.map(r -> null)); + .execute( + ActionListener.releaseAfter( + ActionListener.runAfter( + ActionListener.noop() + .delegateResponse((l, e) -> logger.debug(() -> "setting indices [" + readOnly + "] read-only failed", e)), + this::setLastRunTimeMillis + ), + onCompletion + ) + ); } private void removeExistingIndexBlocks() { diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestService.java b/server/src/main/java/org/elasticsearch/ingest/IngestService.java index 3df1072c6861c..a6cf27446963c 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestService.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestService.java @@ -22,7 +22,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.ingest.DeletePipelineRequest; import org.elasticsearch.action.ingest.PutPipelineRequest; -import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; @@ -687,43 +687,46 @@ public void onFailure(Exception e) { @Override protected void doRun() { final Thread originalThread = Thread.currentThread(); - final ActionListener onFinished = new CountDownActionListener( - numberOfActionRequests, - () -> onCompletion.accept(originalThread, null) - ); + try (var refs = new RefCountingRunnable(() -> onCompletion.accept(originalThread, null))) { + int i = 0; + for (DocWriteRequest actionRequest : actionRequests) { + IndexRequest indexRequest = TransportBulkAction.getIndexWriteRequest(actionRequest); + if (indexRequest == null) { + i++; + continue; + } - int i = 0; - for (DocWriteRequest actionRequest : actionRequests) { - IndexRequest indexRequest = TransportBulkAction.getIndexWriteRequest(actionRequest); - if (indexRequest == null) { - onFinished.onResponse(null); - i++; - continue; - } + final String pipelineId = indexRequest.getPipeline(); + indexRequest.setPipeline(NOOP_PIPELINE_NAME); + final String finalPipelineId = indexRequest.getFinalPipeline(); + indexRequest.setFinalPipeline(NOOP_PIPELINE_NAME); + boolean hasFinalPipeline = true; + final List pipelines; + if (IngestService.NOOP_PIPELINE_NAME.equals(pipelineId) == false + && IngestService.NOOP_PIPELINE_NAME.equals(finalPipelineId) == false) { + pipelines = List.of(pipelineId, finalPipelineId); + } else if (IngestService.NOOP_PIPELINE_NAME.equals(pipelineId) == false) { + pipelines = List.of(pipelineId); + hasFinalPipeline = false; + } else if (IngestService.NOOP_PIPELINE_NAME.equals(finalPipelineId) == false) { + pipelines = List.of(finalPipelineId); + } else { + i++; + continue; + } + + executePipelines( + i, + pipelines.iterator(), + hasFinalPipeline, + indexRequest, + onDropped, + onFailure, + refs.acquireListener() + ); - final String pipelineId = indexRequest.getPipeline(); - indexRequest.setPipeline(NOOP_PIPELINE_NAME); - final String finalPipelineId = indexRequest.getFinalPipeline(); - indexRequest.setFinalPipeline(NOOP_PIPELINE_NAME); - boolean hasFinalPipeline = true; - final List pipelines; - if (IngestService.NOOP_PIPELINE_NAME.equals(pipelineId) == false - && IngestService.NOOP_PIPELINE_NAME.equals(finalPipelineId) == false) { - pipelines = List.of(pipelineId, finalPipelineId); - } else if (IngestService.NOOP_PIPELINE_NAME.equals(pipelineId) == false) { - pipelines = List.of(pipelineId); - hasFinalPipeline = false; - } else if (IngestService.NOOP_PIPELINE_NAME.equals(finalPipelineId) == false) { - pipelines = List.of(finalPipelineId); - } else { - onFinished.onResponse(null); i++; - continue; } - - executePipelines(i, pipelines.iterator(), hasFinalPipeline, indexRequest, onDropped, onFailure, onFinished); - - i++; } } }); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 11542919c13e5..f43e157266aba 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -31,6 +31,7 @@ import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.ListenableActionFuture; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateUpdateTask; @@ -965,31 +966,36 @@ private void doDeleteShardSnapshots( writeUpdatedRepoDataStep.whenComplete(updatedRepoData -> { listener.onRepositoryDataWritten(updatedRepoData); // Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion - final ActionListener afterCleanupsListener = new CountDownActionListener(2, listener::onDone); - cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, updatedRepoData, afterCleanupsListener); - asyncCleanupUnlinkedShardLevelBlobs( - repositoryData, - snapshotIds, - writeShardMetaDataAndComputeDeletesStep.result(), - afterCleanupsListener - ); + try (var refs = new RefCountingRunnable(listener::onDone)) { + cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, updatedRepoData, refs.acquireListener()); + asyncCleanupUnlinkedShardLevelBlobs( + repositoryData, + snapshotIds, + writeShardMetaDataAndComputeDeletesStep.result(), + refs.acquireListener() + ); + } }, listener::onFailure); } else { // Write the new repository data first (with the removed snapshot), using no shard generations final RepositoryData updatedRepoData = repositoryData.removeSnapshots(snapshotIds, ShardGenerations.EMPTY); writeIndexGen(updatedRepoData, repositoryStateId, repoMetaVersion, Function.identity(), ActionListener.wrap(newRepoData -> { - // Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion - final ActionListener afterCleanupsListener = new CountDownActionListener(2, () -> { + try (var refs = new RefCountingRunnable(() -> { listener.onRepositoryDataWritten(newRepoData); listener.onDone(); - }); - cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, newRepoData, afterCleanupsListener); - final StepListener> writeMetaAndComputeDeletesStep = new StepListener<>(); - writeUpdatedShardMetaDataAndComputeDeletes(snapshotIds, repositoryData, false, writeMetaAndComputeDeletesStep); - writeMetaAndComputeDeletesStep.whenComplete( - deleteResults -> asyncCleanupUnlinkedShardLevelBlobs(repositoryData, snapshotIds, deleteResults, afterCleanupsListener), - afterCleanupsListener::onFailure - ); + })) { + // Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion + cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, newRepoData, refs.acquireListener()); + writeUpdatedShardMetaDataAndComputeDeletes( + snapshotIds, + repositoryData, + false, + refs.acquireListener() + .delegateFailure( + (l, deleteResults) -> asyncCleanupUnlinkedShardLevelBlobs(repositoryData, snapshotIds, deleteResults, l) + ) + ); + } }, listener::onFailure)); } } diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java index 10ecef65c5e6f..e84f6e5e71e67 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java @@ -14,8 +14,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.StepListener; import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest; -import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateApplier; @@ -250,8 +250,8 @@ public void restoreSnapshot( ) { try { // Try and fill in any missing repository UUIDs in case they're needed during the restore - final StepListener repositoryUuidRefreshListener = new StepListener<>(); - refreshRepositoryUuids(refreshRepositoryUuidOnRestore, repositoriesService, repositoryUuidRefreshListener); + final var repositoryUuidRefreshStep = new StepListener(); + refreshRepositoryUuids(refreshRepositoryUuidOnRestore, repositoriesService, () -> repositoryUuidRefreshStep.onResponse(null)); // Read snapshot info and metadata from the repository final String repositoryName = request.repository(); @@ -259,7 +259,7 @@ public void restoreSnapshot( final StepListener repositoryDataListener = new StepListener<>(); repository.getRepositoryData(repositoryDataListener); - repositoryDataListener.whenComplete(repositoryData -> repositoryUuidRefreshListener.whenComplete(ignored -> { + repositoryDataListener.whenComplete(repositoryData -> repositoryUuidRefreshStep.whenComplete(ignored -> { final String snapshotName = request.snapshot(); final Optional matchingSnapshotId = repositoryData.getSnapshotIds() .stream() @@ -493,59 +493,37 @@ private void setRefreshRepositoryUuidOnRestore(boolean refreshRepositoryUuidOnRe * * @param enabled If {@code false} this method completes the listener immediately * @param repositoriesService Supplies the repositories to check - * @param refreshListener Listener that is completed when all repositories have been refreshed. + * @param onCompletion Action that is executed when all repositories have been refreshed. */ // Exposed for tests - static void refreshRepositoryUuids(boolean enabled, RepositoriesService repositoriesService, ActionListener refreshListener) { - - if (enabled == false) { - logger.debug("repository UUID refresh is disabled"); - refreshListener.onResponse(null); - return; - } - - // We only care about BlobStoreRepositories because they're the only ones that can contain a searchable snapshot, and we only care - // about ones with missing UUIDs. It's possible to have the UUID change from under us if, e.g., the repository was wiped by an - // external force, but in this case any searchable snapshots are lost anyway so it doesn't really matter. - final List repositories = repositoriesService.getRepositories() - .values() - .stream() - .filter( - repository -> repository instanceof BlobStoreRepository - && repository.getMetadata().uuid().equals(RepositoryData.MISSING_UUID) - ) - .toList(); - if (repositories.isEmpty()) { - logger.debug("repository UUID refresh is not required"); - refreshListener.onResponse(null); - return; - } + static void refreshRepositoryUuids(boolean enabled, RepositoriesService repositoriesService, Runnable onCompletion) { + try (var refs = new RefCountingRunnable(onCompletion)) { + if (enabled == false) { + logger.debug("repository UUID refresh is disabled"); + return; + } - logger.info( - "refreshing repository UUIDs for repositories [{}]", - repositories.stream().map(repository -> repository.getMetadata().name()).collect(Collectors.joining(",")) - ); - final ActionListener countDownListener = new CountDownActionListener( - repositories.size(), - new ActionListener() { - @Override - public void onResponse(Void ignored) { - logger.debug("repository UUID refresh completed"); - refreshListener.onResponse(null); - } + for (Repository repository : repositoriesService.getRepositories().values()) { + // We only care about BlobStoreRepositories because they're the only ones that can contain a searchable snapshot, and we + // only care about ones with missing UUIDs. It's possible to have the UUID change from under us if, e.g., the repository was + // wiped by an external force, but in this case any searchable snapshots are lost anyway so it doesn't really matter. + if (repository instanceof BlobStoreRepository && repository.getMetadata().uuid().equals(RepositoryData.MISSING_UUID)) { + final var repositoryName = repository.getMetadata().name(); + logger.info("refreshing repository UUID for repository [{}]", repositoryName); + repository.getRepositoryData(ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(RepositoryData repositoryData) { + logger.debug(() -> format("repository UUID [{}] refresh completed", repositoryName)); + } - @Override - public void onFailure(Exception e) { - logger.debug("repository UUID refresh failed", e); - refreshListener.onResponse(null); // this refresh is best-effort, the restore should proceed either way + @Override + public void onFailure(Exception e) { + logger.debug(() -> format("repository UUID [{}] refresh failed", repositoryName), e); + } + }, refs.acquire())); } } - ).map(repositoryData -> null /* don't collect the RepositoryData */); - - for (Repository repository : repositories) { - repository.getRepositoryData(countDownListener); } - } private boolean isSystemIndex(IndexMetadata indexMetadata) { diff --git a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java index a43864a9938c0..ae9e43a96ebcb 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.Releasable; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -21,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -347,4 +349,85 @@ public void onFailure(Exception e) { mapped.onFailure(new IllegalStateException()); assertThat(exReference.get(), instanceOf(IllegalStateException.class)); } + + public void testReleasing() { + runReleasingTest(true); + runReleasingTest(false); + } + + private static void runReleasingTest(boolean successResponse) { + final AtomicBoolean releasedFlag = new AtomicBoolean(); + final ActionListener l = ActionListener.releasing(makeReleasable(releasedFlag)); + assertThat(l.toString(), containsString("release[test releasable]}")); + completeListener(successResponse, l); + assertTrue(releasedFlag.get()); + } + + private static void completeListener(boolean successResponse, ActionListener listener) { + if (successResponse) { + try { + listener.onResponse(null); + } catch (Exception e) { + // ok + } + } else { + listener.onFailure(new RuntimeException("simulated")); + } + } + + public void testReleaseAfter() { + runReleaseAfterTest(true, false); + runReleaseAfterTest(true, true); + runReleaseAfterTest(false, false); + } + + private static void runReleaseAfterTest(boolean successResponse, final boolean throwFromOnResponse) { + final AtomicBoolean released = new AtomicBoolean(); + final ActionListener l = ActionListener.releaseAfter(new ActionListener<>() { + @Override + public void onResponse(Void unused) { + if (throwFromOnResponse) { + throw new RuntimeException("onResponse"); + } + } + + @Override + public void onFailure(Exception e) { + // ok + } + + @Override + public String toString() { + return "test listener"; + } + }, makeReleasable(released)); + assertThat(l.toString(), containsString("test listener/release[test releasable]")); + + if (successResponse) { + try { + l.onResponse(null); + } catch (Exception e) { + // ok + } + } else { + l.onFailure(new RuntimeException("supplied")); + } + + assertTrue(released.get()); + } + + private static Releasable makeReleasable(AtomicBoolean releasedFlag) { + return new Releasable() { + @Override + public void close() { + assertTrue(releasedFlag.compareAndSet(false, true)); + } + + @Override + public String toString() { + return "test releasable"; + } + }; + } + } diff --git a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java index 7655c2fd172f4..54b98e26b8e64 100644 --- a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java @@ -115,7 +115,6 @@ public void onFailure(Exception e) { // can't use a null listener or runnable expectThrows(NullPointerException.class, () -> new CountDownActionListener(1, (ActionListener) null)); - expectThrows(NullPointerException.class, () -> new CountDownActionListener(1, (Runnable) null)); final int overage = randomIntBetween(1, 10); AtomicInteger assertionsTriggered = new AtomicInteger(); diff --git a/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java b/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java new file mode 100644 index 0000000000000..b5ccc4f50969b --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.support; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.test.ESTestCase; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.common.util.concurrent.EsExecutors.DIRECT_EXECUTOR_SERVICE; +import static org.hamcrest.Matchers.containsString; + +public class RefCountingRunnableTests extends ESTestCase { + + public void testBasicOperation() throws InterruptedException { + final var executed = new AtomicBoolean(); + final var threads = new Thread[between(0, 3)]; + boolean async = false; + final var startLatch = new CountDownLatch(1); + + try (var refs = new RefCountingRunnable(new Runnable() { + @Override + public void run() { + assertTrue(executed.compareAndSet(false, true)); + } + + @Override + public String toString() { + return "test runnable"; + } + })) { + assertEquals("refCounted[test runnable]", refs.toString()); + try (var ref = refs.acquire()) { + assertEquals("refCounted[test runnable]", ref.toString()); + } + var listener = refs.acquireListener(); + assertThat(listener.toString(), containsString("release[refCounted[test runnable]]")); + listener.onResponse(null); + + for (int i = 0; i < threads.length; i++) { + if (randomBoolean()) { + async = true; + var ref = refs.acquire(); + threads[i] = new Thread(() -> { + try (var ignored = ref) { + assertTrue(startLatch.await(10, TimeUnit.SECONDS)); + assertFalse(executed.get()); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + }); + } + } + + assertFalse(executed.get()); + } + + assertNotEquals(async, executed.get()); + + for (Thread thread : threads) { + if (thread != null) { + thread.start(); + } + } + + startLatch.countDown(); + + for (Thread thread : threads) { + if (thread != null) { + thread.join(); + } + } + + assertTrue(executed.get()); + } + + @SuppressWarnings("resource") + public void testNullCheck() { + expectThrows(NullPointerException.class, () -> new RefCountingRunnable(null)); + } + + public void testAsyncAcquire() throws InterruptedException { + final var completionLatch = new CountDownLatch(1); + final var executorService = EsExecutors.newScaling( + "test", + 0, + between(1, 10), + 10, + TimeUnit.SECONDS, + true, + EsExecutors.daemonThreadFactory("test"), + new ThreadContext(Settings.EMPTY) + ); + final var asyncPermits = new Semaphore(between(0, 1000)); + + try (var refs = new RefCountingRunnable(() -> { + assertEquals(1, completionLatch.getCount()); + completionLatch.countDown(); + })) { + class AsyncAction extends AbstractRunnable { + private final Releasable ref; + + AsyncAction(Releasable ref) { + this.ref = ref; + } + + @Override + protected void doRun() throws Exception { + if (asyncPermits.tryAcquire()) { + executorService.execute(new AsyncAction(refs.acquire())); + } + } + + @Override + public void onFailure(Exception e) { + assert e instanceof EsRejectedExecutionException esre && esre.isExecutorShutdown() : e; + } + + @Override + public void onAfter() { + ref.close(); + } + } + + for (int i = between(0, 5); i >= 0; i--) { + executorService.execute(new AsyncAction(refs.acquire())); + } + + assertEquals(1, completionLatch.getCount()); + } + + if (randomBoolean()) { + assertTrue(completionLatch.await(10, TimeUnit.SECONDS)); + assertFalse(asyncPermits.tryAcquire()); + } + + executorService.shutdown(); + assertTrue(executorService.awaitTermination(10, TimeUnit.SECONDS)); + + assertTrue(completionLatch.await(10, TimeUnit.SECONDS)); + } + + public void testValidation() { + final var callCount = new AtomicInteger(); + final var refs = new RefCountingRunnable(callCount::incrementAndGet); + refs.close(); + assertEquals(1, callCount.get()); + + for (int i = between(1, 5); i > 0; i--) { + final ThrowingRunnable throwingRunnable; + final String expectedMessage; + if (randomBoolean()) { + throwingRunnable = randomBoolean() ? refs::acquire : refs::acquireListener; + expectedMessage = RefCountingRunnable.ALREADY_CLOSED_MESSAGE; + } else { + throwingRunnable = refs::close; + expectedMessage = "invalid decRef call: already closed"; + } + + assertEquals(expectedMessage, expectThrows(AssertionError.class, throwingRunnable).getMessage()); + assertEquals(1, callCount.get()); + } + } + + public void testJavaDocExample() { + final var flag = new AtomicBoolean(); + runExample(() -> assertTrue(flag.compareAndSet(false, true))); + assertTrue(flag.get()); + } + + private void runExample(Runnable finalRunnable) { + final var collection = randomList(10, Object::new); + final var otherCollection = randomList(10, Object::new); + final var flag = randomBoolean(); + @SuppressWarnings("UnnecessaryLocalVariable") + final var executorService = DIRECT_EXECUTOR_SERVICE; + + try (var refs = new RefCountingRunnable(finalRunnable)) { + for (var item : collection) { + if (condition(item)) { + runAsyncAction(item, refs.acquire()); + } + } + if (flag) { + runOneOffAsyncAction(refs.acquire()); + return; + } + for (var item : otherCollection) { + var itemRef = refs.acquire(); // delays completion while the background action is pending + executorService.execute(() -> { + try (var ignored = itemRef) { + if (condition(item)) { + runOtherAsyncAction(item, refs.acquire()); + } + } + }); + } + } + } + + @SuppressWarnings("unused") + private boolean condition(Object item) { + return randomBoolean(); + } + + @SuppressWarnings("unused") + private void runAsyncAction(Object item, Releasable releasable) { + releasable.close(); + } + + @SuppressWarnings("unused") + private void runOtherAsyncAction(Object item, Releasable releasable) { + releasable.close(); + } + + private void runOneOffAsyncAction(Releasable releasable) { + releasable.close(); + } +} diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java index 1ac0313579992..5de084cc6caf4 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitorTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -121,10 +122,10 @@ private void doTestMarkFloodStageIndicesReadOnly(boolean testMaxHeadroom) { ) { @Override - protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, Releasable onCompletion, boolean readOnly) { assertTrue(indices.compareAndSet(null, indicesToMarkReadOnly)); assertTrue(readOnly); - listener.onResponse(null); + onCompletion.close(); } }; @@ -216,10 +217,10 @@ protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, ActionLi } ) { @Override - protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, Releasable onCompletion, boolean readOnly) { assertTrue(indices.compareAndSet(null, indicesToMarkReadOnly)); assertTrue(readOnly); - listener.onResponse(null); + onCompletion.close(); } }; @@ -276,7 +277,7 @@ private void doTestDoesNotSubmitRerouteTaskTooFrequently(boolean testMaxHeadroom } ) { @Override - protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, Releasable onCompletion, boolean readOnly) { throw new AssertionError("unexpected"); } }; @@ -472,13 +473,13 @@ private void doTestAutoReleaseIndices(boolean testMaxHeadroom) { } ) { @Override - protected void updateIndicesReadOnly(Set indicesToUpdate, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToUpdate, Releasable onCompletion, boolean readOnly) { if (readOnly) { assertTrue(indicesToMarkReadOnly.compareAndSet(null, indicesToUpdate)); } else { assertTrue(indicesToRelease.compareAndSet(null, indicesToUpdate)); } - listener.onResponse(null); + onCompletion.close(); } }; indicesToMarkReadOnly.set(null); @@ -564,13 +565,13 @@ protected void updateIndicesReadOnly(Set indicesToUpdate, ActionListener } ) { @Override - protected void updateIndicesReadOnly(Set indicesToUpdate, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToUpdate, Releasable onCompletion, boolean readOnly) { if (readOnly) { assertTrue(indicesToMarkReadOnly.compareAndSet(null, indicesToUpdate)); } else { assertTrue(indicesToRelease.compareAndSet(null, indicesToUpdate)); } - listener.onResponse(null); + onCompletion.close(); } }; // When free disk on any of node1 or node2 goes below the flood watermark, then apply index block on indices not having the block @@ -807,13 +808,13 @@ private void doTestNoAutoReleaseOfIndicesOnReplacementNodes(boolean testMaxHeadr } ) { @Override - protected void updateIndicesReadOnly(Set indicesToUpdate, ActionListener listener, boolean readOnly) { + protected void updateIndicesReadOnly(Set indicesToUpdate, Releasable onCompletion, boolean readOnly) { if (readOnly) { assertTrue(indicesToMarkReadOnly.compareAndSet(null, indicesToUpdate)); } else { assertTrue(indicesToRelease.compareAndSet(null, indicesToUpdate)); } - listener.onResponse(null); + onCompletion.close(); } }; indicesToMarkReadOnly.set(null); @@ -1051,8 +1052,8 @@ public long getAsLong() { (reason, priority, listener) -> listener.onResponse(clusterStateRef.get()) ) { @Override - protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, ActionListener listener, boolean readOnly) { - listener.onResponse(null); + protected void updateIndicesReadOnly(Set indicesToMarkReadOnly, Releasable onCompletion, boolean readOnly) { + onCompletion.close(); } @Override diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java index 10038993f4c74..f0aa256ff0317 100644 --- a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java +++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java @@ -11,7 +11,7 @@ import org.apache.lucene.store.ByteBuffersDirectory; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; @@ -68,13 +68,10 @@ public void setTaskRunner(ShardSnapshotTaskRunner taskRunner) { public void snapshotShard(SnapshotShardContext context) { int filesToUpload = randomIntBetween(0, 10); - if (filesToUpload == 0) { - finishedShardSnapshots.incrementAndGet(); - } else { - expectedFileSnapshotTasks.addAndGet(filesToUpload); - ActionListener uploadListener = new CountDownActionListener(filesToUpload, finishedShardSnapshots::incrementAndGet); + expectedFileSnapshotTasks.addAndGet(filesToUpload); + try (var refs = new RefCountingRunnable(finishedShardSnapshots::incrementAndGet)) { for (int i = 0; i < filesToUpload; i++) { - taskRunner.enqueueFileSnapshot(context, ShardSnapshotTaskRunnerTests::dummyFileInfo, uploadListener); + taskRunner.enqueueFileSnapshot(context, ShardSnapshotTaskRunnerTests::dummyFileInfo, refs.acquireListener()); } } finishedShardSnapshotTasks.incrementAndGet(); diff --git a/server/src/test/java/org/elasticsearch/snapshots/RestoreServiceTests.java b/server/src/test/java/org/elasticsearch/snapshots/RestoreServiceTests.java index cfd80d2b00cf1..a1f39c4f16be2 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/RestoreServiceTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/RestoreServiceTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest; -import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.DataStreamTestHelper; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -33,7 +32,7 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -122,16 +121,14 @@ public void testPrefixNotChanged() { } public void testRefreshRepositoryUuidsDoesNothingIfDisabled() { - final PlainActionFuture listener = new PlainActionFuture<>(); final RepositoriesService repositoriesService = mock(RepositoriesService.class); - RestoreService.refreshRepositoryUuids(false, repositoriesService, listener); - assertTrue(listener.isDone()); + final AtomicBoolean called = new AtomicBoolean(); + RestoreService.refreshRepositoryUuids(false, repositoriesService, () -> assertTrue(called.compareAndSet(false, true))); + assertTrue(called.get()); verifyNoMoreInteractions(repositoriesService); } - public void testRefreshRepositoryUuidsRefreshesAsNeeded() throws Exception { - final PlainActionFuture listener = new PlainActionFuture<>(); - + public void testRefreshRepositoryUuidsRefreshesAsNeeded() { final int repositoryCount = between(1, 5); final Map repositories = Maps.newMapWithExpectedSize(repositoryCount); final Set pendingRefreshes = new HashSet<>(); @@ -177,8 +174,9 @@ public void testRefreshRepositoryUuidsRefreshesAsNeeded() throws Exception { final RepositoriesService repositoriesService = mock(RepositoriesService.class); when(repositoriesService.getRepositories()).thenReturn(repositories); - RestoreService.refreshRepositoryUuids(true, repositoriesService, listener); - assertNull(listener.get(0L, TimeUnit.SECONDS)); + final AtomicBoolean completed = new AtomicBoolean(); + RestoreService.refreshRepositoryUuids(true, repositoriesService, () -> assertTrue(completed.compareAndSet(false, true))); + assertTrue(completed.get()); assertThat(pendingRefreshes, empty()); finalAssertions.forEach(Runnable::run); } diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java index 5ef79116205dd..43030941e8028 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -32,7 +33,6 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; -import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.xcontent.StatusToXContentObject; import org.elasticsearch.core.TimeValue; import org.elasticsearch.repositories.RepositoriesService; @@ -46,9 +46,7 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ReceiveTimeoutTransportException; -import org.elasticsearch.transport.TransportException; import org.elasticsearch.transport.TransportRequestOptions; -import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentBuilder; @@ -362,8 +360,7 @@ public static class AsyncAction { private final Queue queue = ConcurrentCollections.newQueue(); private final AtomicReference failure = new AtomicReference<>(); private final Semaphore innerFailures = new Semaphore(5); // limit the number of suppressed failures - private final int workerCount; - private final CountDown workerCountdown; + private final RefCountingRunnable requestRefs = new RefCountingRunnable(this::runCleanUp); private final Set expectedBlobs = ConcurrentCollections.newConcurrentSet(); private final List responses; private final RepositoryPerformanceSummary.Builder summary = new RepositoryPerformanceSummary.Builder(); @@ -386,9 +383,6 @@ public AsyncAction( this.timeoutTimeMillis = currentTimeMillisSupplier.getAsLong() + request.getTimeout().millis(); this.listener = listener; - this.workerCount = request.getConcurrency(); - this.workerCountdown = new CountDown(workerCount); - responses = new ArrayList<>(request.blobCount); } @@ -440,7 +434,6 @@ private boolean isRunning() { public void run() { assert queue.isEmpty() : "must only run action once"; assert failure.get() == null : "must only run action once"; - assert workerCountdown.isCountedDown() == false : "must only run action once"; logger.info("running analysis of repository [{}] using path [{}]", request.getRepositoryName(), blobPath); @@ -477,8 +470,10 @@ && rarely(random) queue.add(verifyBlobTask); } - for (int i = 0; i < workerCount; i++) { - processNextTask(); + try (var ignored = requestRefs) { + for (int i = 0; i < request.getConcurrency(); i++) { + processNextTask(); + } } } @@ -488,9 +483,7 @@ private boolean rarely(Random random) { private void processNextTask() { final VerifyBlobTask thisTask = queue.poll(); - if (isRunning() == false || thisTask == null) { - onWorkerCompletion(); - } else { + if (isRunning() && thisTask != null) { logger.trace("processing [{}]", thisTask); // NB although all this is on the SAME thread, the per-blob verification runs on a SNAPSHOT thread so we don't have to worry // about local requests resulting in a stack overflow here @@ -503,9 +496,9 @@ private void processNextTask() { thisTask.request, task, transportRequestOptions, - new TransportResponseHandler() { + new ActionListenerResponseHandler<>(ActionListener.releaseAfter(new ActionListener<>() { @Override - public void handleResponse(BlobAnalyzeAction.Response response) { + public void onResponse(BlobAnalyzeAction.Response response) { logger.trace("finished [{}]", thisTask); if (thisTask.request.getAbortWrite() == false) { expectedBlobs.add(thisTask.request.getBlobName()); // each task cleans up its own mess on failure @@ -520,17 +513,11 @@ public void handleResponse(BlobAnalyzeAction.Response response) { } @Override - public void handleException(TransportException exp) { + public void onFailure(Exception exp) { logger.debug(() -> "failed [" + thisTask + "]", exp); fail(exp); - onWorkerCompletion(); } - - @Override - public BlobAnalyzeAction.Response read(StreamInput in) throws IOException { - return new BlobAnalyzeAction.Response(in); - } - } + }, requestRefs.acquire()), BlobAnalyzeAction.Response::new) ); } @@ -540,16 +527,14 @@ private BlobContainer getBlobContainer() { return repository.blobStore().blobContainer(repository.basePath().add(blobPath)); } - private void onWorkerCompletion() { - if (workerCountdown.countDown()) { - transportService.getThreadPool().executor(ThreadPool.Names.SNAPSHOT).execute(ActionRunnable.wrap(listener, l -> { - final long listingStartTimeNanos = System.nanoTime(); - ensureConsistentListing(); - final long deleteStartTimeNanos = System.nanoTime(); - deleteContainer(); - sendResponse(listingStartTimeNanos, deleteStartTimeNanos); - })); - } + private void runCleanUp() { + transportService.getThreadPool().executor(ThreadPool.Names.SNAPSHOT).execute(ActionRunnable.wrap(listener, l -> { + final long listingStartTimeNanos = System.nanoTime(); + ensureConsistentListing(); + final long deleteStartTimeNanos = System.nanoTime(); + deleteContainer(); + sendResponse(listingStartTimeNanos, deleteStartTimeNanos); + })); } private void ensureConsistentListing() { @@ -650,20 +635,7 @@ private void sendResponse(final long listingStartTimeNanos, final long deleteSta } } - private static class VerifyBlobTask { - final DiscoveryNode node; - final BlobAnalyzeAction.Request request; - - VerifyBlobTask(DiscoveryNode node, BlobAnalyzeAction.Request request) { - this.node = node; - this.request = request; - } - - @Override - public String toString() { - return "VerifyBlobTask{" + "node=" + node + ", request=" + request + '}'; - } - } + private record VerifyBlobTask(DiscoveryNode node, BlobAnalyzeAction.Request request) {} } public static class Request extends ActionRequest { From e6fdb183c592e9eb45377f12969f7e5f48530a5c Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 3 Jan 2023 12:45:19 +0000 Subject: [PATCH 2/5] Fixup --- .../elasticsearch/action/support/RefCountingRunnable.java | 6 +----- .../org/elasticsearch/cluster/NodeConnectionsService.java | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java index b78fe212386b4..323282ac96780 100644 --- a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java +++ b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java @@ -16,8 +16,6 @@ import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; -import java.util.Objects; - /** * A mechanism to trigger an action on the completion of some (dynamic) collection of other actions. Basic usage is as follows: * @@ -66,7 +64,6 @@ public final class RefCountingRunnable implements Releasable { private static final Logger logger = LogManager.getLogger(RefCountingRunnable.class); static final String ALREADY_CLOSED_MESSAGE = "already closed, cannot acquire or release any further refs"; - private final Runnable delegate; // TODO drop this when #92616 merged private final RefCounted refCounted; /** @@ -74,7 +71,6 @@ public final class RefCountingRunnable implements Releasable { * @param delegate The action to execute when all refs are released. This action must not throw any exception. */ public RefCountingRunnable(Runnable delegate) { - this.delegate = Objects.requireNonNull(delegate); this.refCounted = AbstractRefCounted.of(delegate); } @@ -114,7 +110,7 @@ public void close() { @Override public String toString() { - return "refCounted[" + delegate + "]"; // TODO refCounted.toString() when #92616 merged + return refCounted.toString(); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java index 1d55c5154f9e1..8809a883bd978 100644 --- a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java @@ -266,7 +266,7 @@ private synchronized void releaseListener() { } private void doConnect() { - //noinspection resource + // noinspection resource var refs = acquireRefs(); if (refs == null) { return; From 7422057570458b96bccea0171279716db222fcb3 Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 5 Jan 2023 15:59:34 +0000 Subject: [PATCH 3/5] Javadocs --- .../elasticsearch/action/support/RefCountingRunnable.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java index 323282ac96780..2667236d68fd5 100644 --- a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java +++ b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java @@ -77,6 +77,8 @@ public RefCountingRunnable(Runnable delegate) { /** * Acquire a reference to this object and return an action which releases it. The delegate {@link Runnable} is called when all its * references have been released. + * + * Callers must take care to close the returned resource exactly once. This deviates from the contract of {@link java.io.Closeable}. */ public Releasable acquire() { if (refCounted.tryIncRef()) { @@ -96,7 +98,9 @@ public ActionListener acquireListener() { } /** - * Release a reference to this object, and execute the delegate {@link Runnable} if there are no other references. + * Release the original reference to this object, which executes the delegate {@link Runnable} if there are no other references. + * + * Callers must take care to close this resource exactly once. This deviates from the contract of {@link java.io.Closeable}. */ @Override public void close() { From 01e96180e5b4d7441504778e6ca75479d5fa106d Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 9 Jan 2023 09:06:02 +0000 Subject: [PATCH 4/5] Enforce one-shot closes as per Releasable contract --- .../action/support/RefCountingRunnable.java | 38 +++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java index 2667236d68fd5..0fea6dbeb1ad8 100644 --- a/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java +++ b/server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java @@ -16,6 +16,8 @@ import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; +import java.util.concurrent.atomic.AtomicBoolean; + /** * A mechanism to trigger an action on the completion of some (dynamic) collection of other actions. Basic usage is as follows: * @@ -65,6 +67,21 @@ public final class RefCountingRunnable implements Releasable { static final String ALREADY_CLOSED_MESSAGE = "already closed, cannot acquire or release any further refs"; private final RefCounted refCounted; + private final AtomicBoolean originalRefReleased = new AtomicBoolean(); + + private class AcquiredRef implements Releasable { + private final AtomicBoolean released = new AtomicBoolean(); + + @Override + public void close() { + releaseRef(released); + } + + @Override + public String toString() { + return RefCountingRunnable.this.toString(); + } + } /** * Construct a {@link RefCountingRunnable} which executes {@code delegate} when all refs are released. @@ -82,8 +99,7 @@ public RefCountingRunnable(Runnable delegate) { */ public Releasable acquire() { if (refCounted.tryIncRef()) { - // closing ourselves releases a ref, so we can just return 'this' and avoid any allocation; callers only see a Releasable - return this; + return new AcquiredRef(); } assert false : ALREADY_CLOSED_MESSAGE; throw new IllegalStateException(ALREADY_CLOSED_MESSAGE); @@ -104,11 +120,19 @@ public ActionListener acquireListener() { */ @Override public void close() { - try { - refCounted.decRef(); - } catch (Exception e) { - logger.error("exception in delegate", e); - assert false : e; + releaseRef(originalRefReleased); + } + + private void releaseRef(AtomicBoolean released) { + if (released.compareAndSet(false, true)) { + try { + refCounted.decRef(); + } catch (Exception e) { + logger.error("exception in delegate", e); + assert false : e; + } + } else { + assert false : "already closed"; } } From 94bcfc63b79ac86e3d84cca2d466dc15aacfb2dc Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 9 Jan 2023 09:34:51 +0000 Subject: [PATCH 5/5] Fix test --- .../elasticsearch/action/support/RefCountingRunnableTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java b/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java index b5ccc4f50969b..41e9e4044024a 100644 --- a/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java @@ -169,7 +169,7 @@ public void testValidation() { expectedMessage = RefCountingRunnable.ALREADY_CLOSED_MESSAGE; } else { throwingRunnable = refs::close; - expectedMessage = "invalid decRef call: already closed"; + expectedMessage = "already closed"; } assertEquals(expectedMessage, expectThrows(AssertionError.class, throwingRunnable).getMessage());