Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/127371.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 127371
summary: Add cancellation support in `TransportGetAllocationStatsAction`
area: Allocation
type: feature
issues:
- 123248
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.SingleResultDeduplicator;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.MasterNodeReadRequest;
import org.elasticsearch.action.support.master.TransportMasterNodeReadAction;
Expand All @@ -35,15 +35,20 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction<
TransportGetAllocationStatsAction.Request,
Expand All @@ -62,8 +67,10 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc
);

private final AllocationStatsCache allocationStatsCache;
private final SingleResultDeduplicator<Map<String, NodeAllocationStats>> allocationStatsSupplier;
private final Consumer<ActionListener<Map<String, NodeAllocationStats>>> allocationStatsSupplier;
private final DiskThresholdSettings diskThresholdSettings;
private SubscribableListener<Response> waitingListeners;
private List<TaskListenerPair> tasksList;

@Inject
public TransportGetAllocationStatsAction(
Expand All @@ -87,19 +94,19 @@ public TransportGetAllocationStatsAction(
);
final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL);
this.allocationStatsSupplier = new SingleResultDeduplicator<>(threadPool.getThreadContext(), l -> {
this.allocationStatsSupplier = l -> {
final var cachedStats = allocationStatsCache.get();
if (cachedStats != null) {
l.onResponse(cachedStats);
return;
}

managementExecutor.execute(ActionRunnable.supply(l, () -> {
final var stats = allocationStatsService.stats();
final var stats = allocationStatsService.stats(this::ensureNotCancelled);
allocationStatsCache.put(stats);
return stats;
}));
});
};
this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings());
clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL);
}
Expand All @@ -118,13 +125,65 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
// NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool

final SubscribableListener<Map<String, NodeAllocationStats>> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS)
? SubscribableListener.newForked(allocationStatsSupplier::execute)
: SubscribableListener.newSucceeded(Map.of());
if (request.metrics().contains(Metric.ALLOCATIONS) == false) {
listener.onResponse(statsToResponse(Map.of(), request));
return;
}
// Perform a cheap check for the cached stats up front.
final var cachedStats = allocationStatsCache.get();
if (cachedStats != null) {
listener.onResponse(statsToResponse(cachedStats, request));
return;
}

assert task instanceof CancellableTask;
final var wrappedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
final var taskListenerPair = new TaskListenerPair((CancellableTask) task, wrappedListener);

synchronized (this) {
if (waitingListeners != null) {
tasksList.add(taskListenerPair);
waitingListeners.addListener(wrappedListener);
return;
}

allocationStatsStep.andThenApply(
allocationStats -> new Response(allocationStats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null)
).addListener(listener);
tasksList = new ArrayList<>();
waitingListeners = new SubscribableListener<>();
tasksList.add(taskListenerPair);
waitingListeners.addListener(ActionListener.runBefore(wrappedListener, () -> {
synchronized (this) {
waitingListeners = null;
tasksList = null;
}
}));
}

SubscribableListener.newForked(allocationStatsSupplier::accept)
.andThenApply(stats -> statsToResponse(stats, request))
.addListener(waitingListeners);
}

private Response statsToResponse(Map<String, NodeAllocationStats> stats, Request request) {
return new Response(stats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null);
}

private void ensureNotCancelled() {
final int count;
synchronized (this) {
count = tasksList.size();
}
boolean allTasksCancelled = true;
// Check each task to give each task a chance to invoke their listener (once) when cancelled.
for (int i = 0; i < count; ++i) {
final TaskListenerPair taskPair;
synchronized (this) {
taskPair = tasksList.get(i);
}
allTasksCancelled &= taskPair.isCancelled();
}
if (allTasksCancelled) {
throw new TaskCancelledException("task cancelled");
}
}

@Override
Expand Down Expand Up @@ -167,6 +226,11 @@ public EnumSet<Metric> metrics() {
public ActionRequestValidationException validate() {
return null;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers);
}
}

public static class Response extends ActionResponse {
Expand Down Expand Up @@ -245,4 +309,24 @@ void put(Map<String, NodeAllocationStats> stats) {
}
}
}

private static class TaskListenerPair {
private final CancellableTask task;
private final ActionListener<Response> listener;
private boolean detectedCancellation;

TaskListenerPair(CancellableTask task, ActionListener<Response> listener) {
this.task = task;
this.listener = listener;
this.detectedCancellation = false;
}

boolean isCancelled() {
if (detectedCancellation == false && task.isCancelled()) {
detectedCancellation = true;
listener.onFailure(new TaskCancelledException("task cancelled"));
}
return task.isCancelled();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,22 @@ public AllocationStatsService(
* Returns a map of node IDs to node allocation stats.
*/
public Map<String, NodeAllocationStats> stats() {
return stats(() -> {});
}

/**
* Returns a map of node IDs to node allocation stats, promising to execute the provided {@link Runnable} during the computation to
* test for cancellation.
*/
public Map<String, NodeAllocationStats> stats(Runnable ensureNotCancelled) {
assert Transports.assertNotTransportThread("too expensive for a transport worker");

var clusterState = clusterService.state();
var nodesStatsAndWeights = nodeAllocationStatsAndWeightsCalculator.nodesAllocationStatsAndWeights(
clusterState.metadata(),
clusterState.getRoutingNodes(),
clusterInfoService.getClusterInfo(),
ensureNotCancelled,
desiredBalanceSupplier.get()
);
return nodesStatsAndWeights.entrySet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public Map<String, NodeAllocationStatsAndWeight> nodesAllocationStatsAndWeights(
Metadata metadata,
RoutingNodes routingNodes,
ClusterInfo clusterInfo,
Runnable ensureNotCancelled,
@Nullable DesiredBalance desiredBalance
) {
if (metadata.hasAnyIndices()) {
Expand All @@ -78,6 +79,7 @@ public Map<String, NodeAllocationStatsAndWeight> nodesAllocationStatsAndWeights(
long forecastedDiskUsage = 0;
long currentDiskUsage = 0;
for (ShardRouting shardRouting : node) {
ensureNotCancelled.run();
if (shardRouting.relocating()) {
// Skip the shard if it is moving off this node. The node running recovery will count it.
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ private void updateDesireBalanceMetrics(
routingAllocation.metadata(),
routingAllocation.routingNodes(),
routingAllocation.clusterInfo(),
() -> {},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this into a static constant so that it has a name, something like NEVER_CANCELLED perhaps? Otherwise it leaves the reader wondering what this lambda is for (and also saves allocating a fresh lambda each time it's called, although the compiler might be clever enough to skip that anyway)

desiredBalance
);
Map<DiscoveryNode, NodeAllocationStatsAndWeightsCalculator.NodeAllocationStatsAndWeight> filteredNodeAllocationStatsAndWeights =
Expand Down
Loading
Loading