Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/127371.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 127371
summary: Add cancellation support in `TransportGetAllocationStatsAction`
area: Allocation
type: feature
issues:
- 123248
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.SingleResultDeduplicator;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.SubscribableListener;
Expand All @@ -31,10 +30,12 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.CancellableSingleObjectCache;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -43,7 +44,8 @@
import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.Executor;
import java.util.function.BooleanSupplier;

public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction<
TransportGetAllocationStatsAction.Request,
Expand All @@ -62,7 +64,6 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc
);

private final AllocationStatsCache allocationStatsCache;
private final SingleResultDeduplicator<Map<String, NodeAllocationStats>> allocationStatsSupplier;
private final DiskThresholdSettings diskThresholdSettings;

@Inject
Expand All @@ -85,21 +86,7 @@ public TransportGetAllocationStatsAction(
// very cheaply.
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL);
this.allocationStatsSupplier = new SingleResultDeduplicator<>(threadPool.getThreadContext(), l -> {
final var cachedStats = allocationStatsCache.get();
if (cachedStats != null) {
l.onResponse(cachedStats);
return;
}

managementExecutor.execute(ActionRunnable.supply(l, () -> {
final var stats = allocationStatsService.stats();
allocationStatsCache.put(stats);
return stats;
}));
});
this.allocationStatsCache = new AllocationStatsCache(threadPool, allocationStatsService, DEFAULT_CACHE_TTL);
this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings());
clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL);
}
Expand All @@ -118,8 +105,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
// NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool

assert task instanceof CancellableTask;
final var cancellableTask = (CancellableTask) task;

final SubscribableListener<Map<String, NodeAllocationStats>> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS)
? SubscribableListener.newForked(allocationStatsSupplier::execute)
? SubscribableListener.newForked(l -> allocationStatsCache.get(cancellableTask::isCancelled, l))
: SubscribableListener.newSucceeded(Map.of());

allocationStatsStep.andThenApply(
Expand Down Expand Up @@ -167,6 +157,11 @@ public EnumSet<Metric> metrics() {
public ActionRequestValidationException validate() {
// This request performs no validation of its own: it never reports a validation error and always returns null.
return null;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
    // Always hand back a CancellableTask for this request, built from the supplied arguments plus an empty string
    // (presumably the task description - confirm against the CancellableTask constructor).
    final String emptyDescription = "";
    return new CancellableTask(id, type, action, emptyDescription, parentTaskId, headers);
}
}

public static class Response extends ActionResponse {
Expand Down Expand Up @@ -209,39 +204,60 @@ public DiskThresholdSettings getDiskThresholdSettings() {
}
}

private record CachedAllocationStats(Map<String, NodeAllocationStats> stats, long timestampMillis) {}

private static class AllocationStatsCache {
private static class AllocationStatsCache extends CancellableSingleObjectCache<Long, Long, Map<String, NodeAllocationStats>> {
private volatile long ttlMillis;
private final ThreadPool threadPool;
private final AtomicReference<CachedAllocationStats> cachedStats;
private final Executor executor;
private final AllocationStatsService allocationStatsService;

AllocationStatsCache(ThreadPool threadPool, TimeValue ttl) {
AllocationStatsCache(ThreadPool threadPool, AllocationStatsService allocationStatsService, TimeValue ttl) {
super(threadPool.getThreadContext());
this.threadPool = threadPool;
this.cachedStats = new AtomicReference<>();
this.executor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
this.allocationStatsService = allocationStatsService;
setTTL(ttl);
}

void setTTL(TimeValue ttl) {
ttlMillis = ttl.millis();
if (ttlMillis == 0L) {
cachedStats.set(null);
}
clearCacheIfDisabled();
}

Map<String, NodeAllocationStats> get() {
if (ttlMillis == 0L) {
return null;
void get(BooleanSupplier isCancelled, ActionListener<Map<String, NodeAllocationStats>> listener) {
get(threadPool.relativeTimeInMillis(), isCancelled, listener);
}

/**
 * Computes a fresh set of allocation stats on {@code executor} and passes the outcome to {@code listener}.
 * <p>
 * The {@code aLong} input is not consulted by the computation. When {@code supersedeIfStale} reports {@code true} nothing is
 * submitted by this method and {@code listener} is not completed here.
 *
 * @param aLong              the input this refresh was started for; unused
 * @param ensureNotCancelled passed on to {@code allocationStatsService.stats(...)} for use during the computation
 * @param supersedeIfStale   checked once, before any work is submitted
 * @param listener           receives the computed stats, or the failure raised while computing them
 */
@Override
protected void refresh(
    Long aLong,
    Runnable ensureNotCancelled,
    BooleanSupplier supersedeIfStale,
    ActionListener<Map<String, NodeAllocationStats>> listener
) {
    if (supersedeIfStale.getAsBoolean()) {
        return;
    }

    // If caching is disabled the item is only cached long enough to prevent duplicate concurrent requests.
    final var clearingListener = ActionListener.runBefore(listener, this::clearCacheIfDisabled);
    executor.execute(ActionRunnable.supply(clearingListener, () -> allocationStatsService.stats(ensureNotCancelled)));
}

// We don't set the atomic ref to null here upon expiration since we know it is about to be replaced with a fresh instance.
final var stats = cachedStats.get();
return stats == null || threadPool.relativeTimeInMillis() - stats.timestampMillis > ttlMillis ? null : stats.stats;
/**
 * Uses the supplied timestamp directly as the cache key: the value passed in is returned unchanged.
 *
 * @param timestampMillis the timestamp, in milliseconds, supplied as the input to the cache lookup
 * @return the same {@code timestampMillis} value
 */
@Override
protected Long getKey(Long timestampMillis) {
return timestampMillis;
}

/**
 * Decides whether an item stored under {@code currentKey} is still usable for a request made with {@code newKey}.
 * <p>
 * With a TTL of zero every item is reported as fresh; otherwise the item is fresh while the difference between the two keys does
 * not exceed the TTL.
 *
 * @param currentKey the key of the existing item
 * @param newKey     the key of the incoming request
 * @return {@code true} if the existing item may be reused
 */
@Override
protected boolean isFresh(Long currentKey, Long newKey) {
    if (ttlMillis == 0) {
        return true;
    }
    final long ageMillis = newKey - currentKey;
    return ageMillis <= ttlMillis;
}

void put(Map<String, NodeAllocationStats> stats) {
if (ttlMillis > 0L) {
cachedStats.set(new CachedAllocationStats(stats, threadPool.relativeTimeInMillis()));
/**
 * Clears the currently cached item when the TTL is zero; does nothing for any other TTL value.
 */
private void clearCacheIfDisabled() {
    if (ttlMillis != 0) {
        return;
    }
    clearCurrentCachedItem();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,22 @@ public AllocationStatsService(
* Returns a map of node IDs to node allocation stats.
*/
public Map<String, NodeAllocationStats> stats() {
    // This overload has no cancellation mechanism, so delegate with a check that does nothing.
    final Runnable noCancellationCheck = () -> {};
    return stats(noCancellationCheck);
}

/**
* Returns a map of node IDs to node allocation stats, promising to execute the provided {@link Runnable} during the computation to
* test for cancellation.
*/
public Map<String, NodeAllocationStats> stats(Runnable ensureNotCancelled) {
assert Transports.assertNotTransportThread("too expensive for a transport worker");

var clusterState = clusterService.state();
var nodesStatsAndWeights = nodeAllocationStatsAndWeightsCalculator.nodesAllocationStatsAndWeights(
clusterState.metadata(),
clusterState.getRoutingNodes(),
clusterInfoService.getClusterInfo(),
ensureNotCancelled,
desiredBalanceSupplier.get()
);
return nodesStatsAndWeights.entrySet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public Map<String, NodeAllocationStatsAndWeight> nodesAllocationStatsAndWeights(
Metadata metadata,
RoutingNodes routingNodes,
ClusterInfo clusterInfo,
Runnable ensureNotCancelled,
@Nullable DesiredBalance desiredBalance
) {
if (metadata.hasAnyIndices()) {
Expand All @@ -78,6 +79,7 @@ public Map<String, NodeAllocationStatsAndWeight> nodesAllocationStatsAndWeights(
long forecastedDiskUsage = 0;
long currentDiskUsage = 0;
for (ShardRouting shardRouting : node) {
ensureNotCancelled.run();
if (shardRouting.relocating()) {
// Skip the shard if it is moving off this node. The node running recovery will count it.
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ public void resetDesiredBalance() {
resetCurrentDesiredBalance = true;
}

/**
* Used as the argument for the {@code ensureNotCancelled} {@code Runnable} when calling the
* {@code nodeAllocationStatsAndWeightsCalculator} since there is no cancellation mechanism when called from
* {@code updateDesireBalanceMetrics()}.
*/
private static final Runnable NEVER_CANCELLED = () -> {};

private void updateDesireBalanceMetrics(
DesiredBalance desiredBalance,
RoutingAllocation routingAllocation,
Expand All @@ -400,6 +407,7 @@ private void updateDesireBalanceMetrics(
routingAllocation.metadata(),
routingAllocation.routingNodes(),
routingAllocation.clusterInfo(),
NEVER_CANCELLED,
desiredBalance
);
Map<DiscoveryNode, NodeAllocationStatsAndWeightsCalculator.NodeAllocationStatsAndWeight> filteredNodeAllocationStatsAndWeights =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ protected boolean isFresh(Key currentKey, Key newKey) {
return currentKey.equals(newKey);
}

/**
 * Discards the currently cached item by resetting its reference to {@code null}, so that the next {@code get()} call results in a
 * {@code refresh()}.
 */
protected final void clearCurrentCachedItem() {
    currentCachedItemRef.set(null);
}

/**
* Start a retrieval for the value associated with the given {@code input}, and pass it to the given {@code listener}.
* <p>
Expand All @@ -110,7 +117,8 @@ protected boolean isFresh(Key currentKey, Key newKey) {
*
* @param input The input to compute the desired value, converted to a {@link Key} to determine if the value that's currently
* cached or pending is fresh enough.
* @param isCancelled Returns {@code true} if the listener no longer requires the value being computed.
* @param isCancelled Returns {@code true} if the listener no longer requires the value being computed. The listener is expected to be
* completed as soon as possible when cancellation is detected.
* @param listener The listener to notify when the desired value becomes available.
*/
public final void get(Input input, BooleanSupplier isCancelled, ActionListener<Value> listener) {
Expand Down Expand Up @@ -230,11 +238,15 @@ boolean addListener(ActionListener<Value> listener, BooleanSupplier isCancelled)
ActionListener.completeWith(listener, future::actionResult);
} else {
// Refresh is still pending; it's not cancelled because there are still references.
future.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
final var cancellableListener = ActionListener.notifyOnce(
ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)
);
future.addListener(cancellableListener);
final AtomicBoolean released = new AtomicBoolean();
cancellationChecks.add(() -> {
if (released.get() == false && isCancelled.getAsBoolean() && released.compareAndSet(false, true)) {
decRef();
cancellableListener.onFailure(new TaskCancelledException("task cancelled"));
}
});
}
Expand Down
Loading
Loading