From 620f23cad88d97deae7178ef423c1d533747981c Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Thu, 24 Apr 2025 22:47:09 -0400 Subject: [PATCH 1/7] Add cancellation support in TransportGetAllocationStatsAction Closes #123248 --- .../TransportGetAllocationStatsAction.java | 106 +++++++++-- .../allocation/AllocationStatsService.java | 9 + ...deAllocationStatsAndWeightsCalculator.java | 2 + .../DesiredBalanceShardsAllocator.java | 1 + ...ransportGetAllocationStatsActionTests.java | 172 ++++++++++++++---- .../AllocationStatsServiceTests.java | 6 +- .../cluster/ESAllocationTestCase.java | 2 + 7 files changed, 252 insertions(+), 46 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java index eecbb3525bda9..de2e3e7758bd5 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java @@ -15,9 +15,9 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; -import org.elasticsearch.action.SingleResultDeduplicator; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.MasterNodeReadRequest; import org.elasticsearch.action.support.master.TransportMasterNodeReadAction; @@ -35,15 +35,20 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.io.IOException; +import java.util.ArrayList; import java.util.EnumSet; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction< TransportGetAllocationStatsAction.Request, @@ -62,8 +67,10 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc ); private final AllocationStatsCache allocationStatsCache; - private final SingleResultDeduplicator> allocationStatsSupplier; + private final Consumer>> allocationStatsSupplier; private final DiskThresholdSettings diskThresholdSettings; + private SubscribableListener waitingListeners; + private List tasksList; @Inject public TransportGetAllocationStatsAction( @@ -87,7 +94,7 @@ public TransportGetAllocationStatsAction( ); final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT); this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL); - this.allocationStatsSupplier = new SingleResultDeduplicator<>(threadPool.getThreadContext(), l -> { + this.allocationStatsSupplier = l -> { final var cachedStats = allocationStatsCache.get(); if (cachedStats != null) { l.onResponse(cachedStats); @@ -95,11 +102,11 @@ public TransportGetAllocationStatsAction( } managementExecutor.execute(ActionRunnable.supply(l, () -> { - final var stats = allocationStatsService.stats(); + final var stats = allocationStatsService.stats(this::ensureNotCancelled); allocationStatsCache.put(stats); return stats; })); - }); + }; this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings()); clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL); } @@ -118,13 +125,65 @@ protected void doExecute(Task task, Request request, ActionListener li protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { // NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool - final SubscribableListener> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS) - ? SubscribableListener.newForked(allocationStatsSupplier::execute) - : SubscribableListener.newSucceeded(Map.of()); + if (request.metrics().contains(Metric.ALLOCATIONS) == false) { + listener.onResponse(statsToResponse(Map.of(), request)); + return; + } + // Perform a cheap check for the cached stats up front. + final var cachedStats = allocationStatsCache.get(); + if (cachedStats != null) { + listener.onResponse(statsToResponse(cachedStats, request)); + return; + } + + assert task instanceof CancellableTask; + final var wrappedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); + final var taskListenerPair = new TaskListenerPair((CancellableTask) task, wrappedListener); + + synchronized (this) { + if (waitingListeners != null) { + tasksList.add(taskListenerPair); + waitingListeners.addListener(wrappedListener); + return; + } - allocationStatsStep.andThenApply( - allocationStats -> new Response(allocationStats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null) - ).addListener(listener); + tasksList = new ArrayList<>(); + waitingListeners = new SubscribableListener<>(); + tasksList.add(taskListenerPair); + waitingListeners.addListener(ActionListener.runBefore(wrappedListener, () -> { + synchronized (this) { + waitingListeners = null; + tasksList = null; + } + })); + } + + SubscribableListener.newForked(allocationStatsSupplier::accept) + .andThenApply(stats -> statsToResponse(stats, request)) + .addListener(waitingListeners); + } + + private Response statsToResponse(Map stats, Request request) { + return new Response(stats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null); + } + + private void ensureNotCancelled() { + final int count; + synchronized (this) { + count = tasksList.size(); + } + boolean allTasksCancelled = true; + // Check each task to give each task a chance to invoke their listener (once) when cancelled. + for (int i = 0; i < count; ++i) { + final TaskListenerPair taskPair; + synchronized (this) { + taskPair = tasksList.get(i); + } + allTasksCancelled &= taskPair.isCancelled(); + } + if (allTasksCancelled) { + throw new TaskCancelledException("task cancelled"); + } } @Override @@ -167,6 +226,11 @@ public EnumSet metrics() { public ActionRequestValidationException validate() { return null; } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "", parentTaskId, headers); + } } public static class Response extends ActionResponse { @@ -245,4 +309,24 @@ void put(Map stats) { } } } + + private static class TaskListenerPair { + private final CancellableTask task; + private final ActionListener listener; + private boolean detectedCancellation; + + TaskListenerPair(CancellableTask task, ActionListener listener) { + this.task = task; + this.listener = listener; + this.detectedCancellation = false; + } + + boolean isCancelled() { + if (detectedCancellation == false && task.isCancelled()) { + detectedCancellation = true; + listener.onFailure(new TaskCancelledException("task cancelled")); + } + return task.isCancelled(); + } + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java index 926a6926c9aea..f31ddd36a2e31 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java @@ -47,6 +47,14 @@ public AllocationStatsService( * Returns a map of node IDs to node allocation stats. */ public Map stats() { + return stats(() -> {}); + } + + /** + * Returns a map of node IDs to node allocation stats, promising to execute the provided {@link Runnable} during the computation to + * test for cancellation. + */ + public Map stats(Runnable ensureNotCancelled) { assert Transports.assertNotTransportThread("too expensive for a transport worker"); var clusterState = clusterService.state(); @@ -54,6 +62,7 @@ public Map stats() { clusterState.metadata(), clusterState.getRoutingNodes(), clusterInfoService.getClusterInfo(), + ensureNotCancelled, desiredBalanceSupplier.get() ); return nodesStatsAndWeights.entrySet() diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java index f85b97125ceba..d1c5f7df00fca 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java @@ -54,6 +54,7 @@ public Map nodesAllocationStatsAndWeights( Metadata metadata, RoutingNodes routingNodes, ClusterInfo clusterInfo, + Runnable ensureNotCancelled, @Nullable DesiredBalance desiredBalance ) { if (metadata.hasAnyIndices()) { @@ -78,6 +79,7 @@ public Map nodesAllocationStatsAndWeights( long forecastedDiskUsage = 0; long currentDiskUsage = 0; for (ShardRouting shardRouting : node) { + ensureNotCancelled.run(); if (shardRouting.relocating()) { // Skip the shard if it is moving off this node. The node running recovery will count it. continue; diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java index e8d8d509282ab..d70802f0b64f9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java @@ -400,6 +400,7 @@ private void updateDesireBalanceMetrics( routingAllocation.metadata(), routingAllocation.routingNodes(), routingAllocation.clusterInfo(), + () -> {}, desiredBalance ); Map filteredNodeAllocationStatsAndWeights = diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java index d60ac5ca47f6d..b37bc92c4729c 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.routing.allocation.AllocationStatsService; import org.elasticsearch.cluster.routing.allocation.NodeAllocationStats; @@ -23,7 +24,9 @@ import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.TimeValue; import org.elasticsearch.node.Node; -import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.TaskCancelHelper; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.telemetry.metric.MeterRegistry; import org.elasticsearch.test.ClusterServiceUtils; @@ -34,7 +37,9 @@ import org.elasticsearch.transport.TransportService; import org.junit.After; import org.junit.Before; +import org.mockito.ArgumentCaptor; +import java.util.Arrays; import java.util.EnumSet; import java.util.List; import java.util.Map; @@ -48,6 +53,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -126,26 +132,22 @@ public void testReturnsOnlyRequestedStats() throws Exception { EnumSet.allOf(Metric.class), EnumSet.copyOf(randomSubsetOf(between(1, Metric.values().length), EnumSet.allOf(Metric.class))) )) { - var request = new TransportGetAllocationStatsAction.Request( - TimeValue.ONE_MINUTE, - new TaskId(randomIdentifier(), randomNonNegativeLong()), - metrics - ); + var request = new TransportGetAllocationStatsAction.Request(TimeValue.ONE_MINUTE, TaskId.EMPTY_TASK_ID, metrics); - when(allocationStatsService.stats()).thenReturn( + when(allocationStatsService.stats(argThat(r -> true))).thenReturn( Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()) ); var future = new PlainActionFuture(); - action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future); + action.masterOperation(getTask(), request, ClusterState.EMPTY_STATE, future); var response = future.get(); if (metrics.contains(Metric.ALLOCATIONS)) { assertThat(response.getNodeAllocationStats(), not(anEmptyMap())); - verify(allocationStatsService, times(++expectedNumberOfStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++expectedNumberOfStatsServiceCalls); } else { assertThat(response.getNodeAllocationStats(), anEmptyMap()); - verify(allocationStatsService, times(expectedNumberOfStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(expectedNumberOfStatsServiceCalls); } if (metrics.contains(Metric.FS)) { @@ -160,7 +162,7 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { disableAllocationStatsCache(); final var requestCounter = new AtomicInteger(); final var isExecuting = new AtomicBoolean(); - when(allocationStatsService.stats()).thenAnswer(invocation -> { + when(allocationStatsService.stats(argThat(r -> true))).thenAnswer(invocation -> { try { assertTrue(isExecuting.compareAndSet(false, true)); assertThat(Thread.currentThread().getName(), containsString("[management]")); @@ -180,16 +182,7 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { final var minRequestIndex = requestCounter.get(); final TransportGetAllocationStatsAction.Response response = safeAwait( - l -> action.masterOperation( - mock(Task.class), - new TransportGetAllocationStatsAction.Request( - TEST_REQUEST_TIMEOUT, - TaskId.EMPTY_TASK_ID, - EnumSet.of(Metric.ALLOCATIONS) - ), - ClusterState.EMPTY_STATE, - l - ) + l -> action.masterOperation(getTask(), getRequest(), ClusterState.EMPTY_STATE, l) ); final var requestIndex = Integer.valueOf(response.getNodeAllocationStats().keySet().iterator().next()); @@ -203,6 +196,111 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { } } + public void testAllTasksCancelledStopsComputationSingleThread() throws InterruptedException { + runAllTasksCancelledStopsComputationTestForNumThreads(1, false); + } + + public void testAllTasksCancelledStopsComputationMultipleThreads() throws InterruptedException { + runAllTasksCancelledStopsComputationTestForNumThreads(between(2, 10), false); + } + + public void testAllTasksCancelledStopsComputationSingleThreadCacheDisabled() throws InterruptedException { + runAllTasksCancelledStopsComputationTestForNumThreads(1, true); + } + + public void testAllTasksCancelledStopsComputationMultipleThreadsCacheDisabled() throws InterruptedException { + runAllTasksCancelledStopsComputationTestForNumThreads(between(2, 10), true); + } + + private void runAllTasksCancelledStopsComputationTestForNumThreads(final int numThreads, final boolean cacheDisabled) + throws InterruptedException { + if (cacheDisabled) { + disableAllocationStatsCache(); + } + final var isExecuting = new AtomicBoolean(); + final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class); + final var tasks = new CancellableTask[numThreads]; + + when(allocationStatsService.stats(ensureNotCancelledCaptor.capture())).thenAnswer(invocation -> { + try { + assertTrue(isExecuting.compareAndSet(false, true)); + Arrays.stream(tasks).forEach(task -> TaskCancelHelper.cancel(task, "cancelled")); + ensureNotCancelledCaptor.getValue().run(); + fail("expected computation to stop when all tasks are cancelled"); + return null; + } finally { + Thread.yield(); + assertTrue(isExecuting.compareAndSet(true, false)); + } + }); + + ESTestCase.startInParallel(numThreads, threadNumber -> { + tasks[threadNumber] = getTask(); + final SubscribableListener listener = SubscribableListener.newForked( + l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l) + ); + safeAwaitFailure(listener); + }); + } + + public void testRunSomeTasksCancelledForSingleThread() throws InterruptedException { + runSomeTasksCancelledForNumThreads(1, false); + } + + public void testRunSomeTasksCancelledForMultipleThreads() throws InterruptedException { + runSomeTasksCancelledForNumThreads(between(2, 10), false); + } + + public void testRunSomeTasksCancelledForSingleThreadCacheDisabled() throws InterruptedException { + runSomeTasksCancelledForNumThreads(1, true); + } + + public void testRunSomeTasksCancelledForMultipleThreadsCacheDisabled() throws InterruptedException { + runSomeTasksCancelledForNumThreads(between(2, 10), true); + } + + private void runSomeTasksCancelledForNumThreads(final int numThreads, final boolean cacheDisabled) throws InterruptedException { + if (cacheDisabled) { + disableAllocationStatsCache(); + } + final var isExecuting = new AtomicBoolean(); + final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class); + final var tasks = new CancellableTask[numThreads]; + final var cancellations = new boolean[numThreads]; + final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()); + + when(allocationStatsService.stats(ensureNotCancelledCaptor.capture())).thenAnswer(invocation -> { + try { + assertTrue(isExecuting.compareAndSet(false, true)); + for (int i = 0; i < numThreads; ++i) { + if (cancellations[i]) { + TaskCancelHelper.cancel(tasks[i], "cancelled"); + } + } + ensureNotCancelledCaptor.getValue().run(); + return stats; + } finally { + Thread.yield(); + assertTrue(isExecuting.compareAndSet(true, false)); + } + }); + + ESTestCase.startInParallel(numThreads, threadNumber -> { + tasks[threadNumber] = getTask(); + cancellations[threadNumber] = threadNumber > 0; + final SubscribableListener listener = SubscribableListener.newForked( + l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l) + ); + listener.addListener(ActionListener.wrap(response -> { assertSame(stats, response.getNodeAllocationStats()); }, e -> { + if (e instanceof TaskCancelledException) { + assertTrue("got an unexpected cancellation exception for thread " + threadNumber, cancellations[threadNumber]); + } else { + fail(e); + } + })); + }); + } + public void testGetStatsWithCachingEnabled() throws Exception { final AtomicReference> allocationStats = new AtomicReference<>(); @@ -211,17 +309,11 @@ public void testGetStatsWithCachingEnabled() throws Exception { final Runnable resetExpectedAllocationStats = () -> { final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()); allocationStats.set(stats); - when(allocationStatsService.stats()).thenReturn(stats); + when(allocationStatsService.stats(argThat(r -> true))).thenReturn(stats); }; final CheckedConsumer, Exception> threadTask = l -> { - final var request = new TransportGetAllocationStatsAction.Request( - TEST_REQUEST_TIMEOUT, - new TaskId(randomIdentifier(), randomNonNegativeLong()), - EnumSet.of(Metric.ALLOCATIONS) - ); - - action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, l.map(response -> { + action.masterOperation(getTask(), getRequest(), ClusterState.EMPTY_STATE, l.map(response -> { assertSame("Expected the cached allocation stats to be returned", response.getNodeAllocationStats(), allocationStats.get()); return null; })); @@ -230,12 +322,12 @@ public void testGetStatsWithCachingEnabled() throws Exception { // Initial cache miss, all threads should get the same value. resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); // Advance the clock to a time less than or equal to the TTL and verify we still get the cached stats. threadPool.setCurrentTimeInMillis(startTimeMillis + between(0, (int) allocationStatsCacheTTL.millis())); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(numExpectedAllocationStatsServiceCalls); // Force the cached stats to expire. threadPool.setCurrentTimeInMillis(startTimeMillis + allocationStatsCacheTTL.getMillis() + 1); @@ -243,20 +335,32 @@ public void testGetStatsWithCachingEnabled() throws Exception { // Expect a single call to the stats service on the cache miss. resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); // Update the TTL setting to disable the cache, we expect a service call each time. setAllocationStatsCacheTTL(TimeValue.ZERO); safeAwait(threadTask); safeAwait(threadTask); numExpectedAllocationStatsServiceCalls += 2; - verify(allocationStatsService, times(numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(numExpectedAllocationStatsServiceCalls); // Re-enable the cache, only one thread should call the stats service. setAllocationStatsCacheTTL(TimeValue.timeValueMinutes(5)); resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); + } + + private void verifyAllocationStatsServiceNumCallsEqualTo(int numCalls) { + verify(allocationStatsService, times(numCalls)).stats(argThat(r -> true)); + } + + private static TransportGetAllocationStatsAction.Request getRequest() { + return new TransportGetAllocationStatsAction.Request(TEST_REQUEST_TIMEOUT, TaskId.EMPTY_TASK_ID, EnumSet.of(Metric.ALLOCATIONS)); + } + + private static CancellableTask getTask() { + return new CancellableTask(randomLong(), "type", "action", "desc", null, Map.of()); } private static class ControlledRelativeTimeThreadPool extends ThreadPool { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java index 756f29384ec5c..da31162bd1b6e 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.test.ClusterServiceUtils; @@ -88,7 +89,7 @@ public void testShardStats() { new NodeAllocationStatsAndWeightsCalculator(TEST_WRITE_LOAD_FORECASTER, BalancerSettings.DEFAULT) ); assertThat( - service.stats(), + service.stats(() -> {}), allOf( aMapWithSize(1), hasEntry( @@ -97,6 +98,9 @@ public void testShardStats() { ) ) ); + + // Verify that the ensureNotCancelled Runnable is tested during execution. + assertThrows(TaskCancelledException.class, () -> service.stats(() -> { throw new TaskCancelledException("cancelled"); })); } } diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java index d2f6c05a5d3f6..a9b2109244270 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java @@ -449,8 +449,10 @@ public Map nodesAllocationStatsAndWeights( Metadata metadata, RoutingNodes routingNodes, ClusterInfo clusterInfo, + Runnable ensureNotCancelled, @Nullable DesiredBalance desiredBalance ) { + ensureNotCancelled.run(); return Map.of(); } }; From 8d6f7ccdbe4317b858ace8a663dbd1307e131b59 Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Thu, 24 Apr 2025 22:59:21 -0400 Subject: [PATCH 2/7] Update docs/changelog/127371.yaml --- docs/changelog/127371.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/127371.yaml diff --git a/docs/changelog/127371.yaml b/docs/changelog/127371.yaml new file mode 100644 index 0000000000000..10f5f17243193 --- /dev/null +++ b/docs/changelog/127371.yaml @@ -0,0 +1,6 @@ +pr: 127371 +summary: Add cancellation support in `TransportGetAllocationStatsAction` +area: Allocation +type: feature +issues: + - 123248 From b423b7be9f7cdbaef1d4bc3c46341197ccb1f234 Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Sat, 26 Apr 2025 20:58:12 -0400 Subject: [PATCH 3/7] onFailure() with TaskCancelledException when cancellation is detected --- .../util/CancellableSingleObjectCache.java | 16 ++++++++-- .../CancellableSingleObjectCacheTests.java | 31 +++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java b/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java index fa8ec26bbad2c..b89587192a061 100644 --- a/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java +++ b/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java @@ -100,6 +100,13 @@ protected boolean isFresh(Key currentKey, Key newKey) { return currentKey.equals(newKey); } + /** + * Sets the currently cached item reference to {@code null}, which will result in a {@code refresh()} on the next {@code get()} call. + */ + protected void clearCurrentCachedItem() { + this.currentCachedItemRef.set(null); + } + /** * Start a retrieval for the value associated with the given {@code input}, and pass it to the given {@code listener}. *

@@ -110,7 +117,8 @@ protected boolean isFresh(Key currentKey, Key newKey) { * * @param input The input to compute the desired value, converted to a {@link Key} to determine if the value that's currently * cached or pending is fresh enough. - * @param isCancelled Returns {@code true} if the listener no longer requires the value being computed. + * @param isCancelled Returns {@code true} if the listener no longer requires the value being computed. The listener is expected to be + * completed as soon as possible when cancellation is detected. * @param listener The listener to notify when the desired value becomes available. */ public final void get(Input input, BooleanSupplier isCancelled, ActionListener listener) { @@ -230,11 +238,15 @@ boolean addListener(ActionListener listener, BooleanSupplier isCancelled) ActionListener.completeWith(listener, future::actionResult); } else { // Refresh is still pending; it's not cancelled because there are still references. - future.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)); + final var cancellableListener = ActionListener.notifyOnce( + ContextPreservingActionListener.wrapPreservingContext(listener, threadContext) + ); + future.addListener(cancellableListener); final AtomicBoolean released = new AtomicBoolean(); cancellationChecks.add(() -> { if (released.get() == false && isCancelled.getAsBoolean() && released.compareAndSet(false, true)) { decRef(); + cancellableListener.onFailure(new TaskCancelledException("task cancelled")); } }); } diff --git a/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java b/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java index ecb74009682d3..a0c0bc64b4c8f 100644 --- a/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java @@ -101,8 +101,35 @@ public void testListenerCompletedByRefreshEvenIfDiscarded() throws ExecutionExce testCache.completeNextRefresh("foo", 1); assertThat(future2.result(), equalTo(1)); - // ... and the original listener is also completed successfully - assertThat(future1.result(), sameInstance(future2.result())); + // We expect the first listener to have been completed with a cancellation exception when detected in the ensureNotCancelled() call. + assertTrue(future1.isDone()); + expectThrows(ExecutionException.class, TaskCancelledException.class, future1::result); + } + + public void testBothListenersReceiveTaskCancelledExceptionWhenBothSupersededAndNewTasksAreCancelled() { + final TestCache testCache = new TestCache(); + + // This computation is superseded and then cancelled. + final AtomicBoolean isCancelled = new AtomicBoolean(); + final TestFuture future1 = new TestFuture(); + testCache.get("foo", isCancelled::get, future1); + testCache.assertPendingRefreshes(1); + + // A second get() call that supersedes the original refresh and starts another one, but will be cancelled as well. + final TestFuture future2 = new TestFuture(); + testCache.get("bar", isCancelled::get, future2); + testCache.assertPendingRefreshes(2); + + testCache.assertNextRefreshCancelled(); + assertFalse(future1.isDone()); + testCache.assertPendingRefreshes(1); + assertFalse(future2.isDone()); + + isCancelled.set(true); + // This next refresh should also fail with a cancellation exception. + testCache.completeNextRefresh("bar", 1); + expectThrows(ExecutionException.class, TaskCancelledException.class, future1::result); + expectThrows(ExecutionException.class, TaskCancelledException.class, future2::result); } public void testListenerCompletedWithCancellationExceptionIfRefreshCancelled() throws ExecutionException { From 91df7175b8156207979771ee61d1de9245549d6d Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Sat, 26 Apr 2025 20:58:43 -0400 Subject: [PATCH 4/7] Refactor to use CancellableSingleObjectCache --- .../TransportGetAllocationStatsAction.java | 160 +++++------------- ...ransportGetAllocationStatsActionTests.java | 82 ++------- 2 files changed, 64 insertions(+), 178 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java index de2e3e7758bd5..42789094f2239 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java @@ -17,7 +17,6 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.MasterNodeReadRequest; import org.elasticsearch.action.support.master.TransportMasterNodeReadAction; @@ -31,24 +30,22 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.util.CancellableSingleObjectCache; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.io.IOException; -import java.util.ArrayList; import java.util.EnumSet; -import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; +import java.util.concurrent.ExecutorService; +import java.util.function.BooleanSupplier; public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction< TransportGetAllocationStatsAction.Request, @@ -67,10 +64,7 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc ); private final AllocationStatsCache allocationStatsCache; - private final Consumer>> allocationStatsSupplier; private final DiskThresholdSettings diskThresholdSettings; - private SubscribableListener waitingListeners; - private List tasksList; @Inject public TransportGetAllocationStatsAction( @@ -92,21 +86,7 @@ public TransportGetAllocationStatsAction( // very cheaply. EsExecutors.DIRECT_EXECUTOR_SERVICE ); - final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT); - this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL); - this.allocationStatsSupplier = l -> { - final var cachedStats = allocationStatsCache.get(); - if (cachedStats != null) { - l.onResponse(cachedStats); - return; - } - - managementExecutor.execute(ActionRunnable.supply(l, () -> { - final var stats = allocationStatsService.stats(this::ensureNotCancelled); - allocationStatsCache.put(stats); - return stats; - })); - }; + this.allocationStatsCache = new AllocationStatsCache(threadPool, allocationStatsService, DEFAULT_CACHE_TTL); this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings()); clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL); } @@ -125,65 +105,16 @@ protected void doExecute(Task task, Request request, ActionListener li protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { // NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool - if (request.metrics().contains(Metric.ALLOCATIONS) == false) { - listener.onResponse(statsToResponse(Map.of(), request)); - return; - } - // Perform a cheap check for the cached stats up front. - final var cachedStats = allocationStatsCache.get(); - if (cachedStats != null) { - listener.onResponse(statsToResponse(cachedStats, request)); - return; - } - assert task instanceof CancellableTask; - final var wrappedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()); - final var taskListenerPair = new TaskListenerPair((CancellableTask) task, wrappedListener); - - synchronized (this) { - if (waitingListeners != null) { - tasksList.add(taskListenerPair); - waitingListeners.addListener(wrappedListener); - return; - } - - tasksList = new ArrayList<>(); - waitingListeners = new SubscribableListener<>(); - tasksList.add(taskListenerPair); - waitingListeners.addListener(ActionListener.runBefore(wrappedListener, () -> { - synchronized (this) { - waitingListeners = null; - tasksList = null; - } - })); - } + final var cancellableTask = (CancellableTask) task; - SubscribableListener.newForked(allocationStatsSupplier::accept) - .andThenApply(stats -> statsToResponse(stats, request)) - .addListener(waitingListeners); - } - - private Response statsToResponse(Map stats, Request request) { - return new Response(stats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null); - } + final SubscribableListener> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS) + ? SubscribableListener.newForked(l -> allocationStatsCache.get(cancellableTask::isCancelled, l)) + : SubscribableListener.newSucceeded(Map.of()); - private void ensureNotCancelled() { - final int count; - synchronized (this) { - count = tasksList.size(); - } - boolean allTasksCancelled = true; - // Check each task to give each task a chance to invoke their listener (once) when cancelled. - for (int i = 0; i < count; ++i) { - final TaskListenerPair taskPair; - synchronized (this) { - taskPair = tasksList.get(i); - } - allTasksCancelled &= taskPair.isCancelled(); - } - if (allTasksCancelled) { - throw new TaskCancelledException("task cancelled"); - } + allocationStatsStep.andThenApply( + allocationStats -> new Response(allocationStats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null) + ).addListener(listener); } @Override @@ -273,60 +204,61 @@ public DiskThresholdSettings getDiskThresholdSettings() { } } - private record CachedAllocationStats(Map stats, long timestampMillis) {} - - private static class AllocationStatsCache { + private static class AllocationStatsCache extends CancellableSingleObjectCache> { private volatile long ttlMillis; private final ThreadPool threadPool; - private final AtomicReference cachedStats; + private final ExecutorService executorService; + private final AllocationStatsService allocationStatsService; - AllocationStatsCache(ThreadPool threadPool, TimeValue ttl) { + AllocationStatsCache(ThreadPool threadPool, AllocationStatsService allocationStatsService, TimeValue ttl) { + super(threadPool.getThreadContext()); this.threadPool = threadPool; - this.cachedStats = new AtomicReference<>(); + this.executorService = threadPool.executor(ThreadPool.Names.MANAGEMENT); + this.allocationStatsService = allocationStatsService; setTTL(ttl); } void setTTL(TimeValue ttl) { ttlMillis = ttl.millis(); - if (ttlMillis == 0L) { - cachedStats.set(null); - } + clearCacheIfDisabled(); } - Map get() { - if (ttlMillis == 0L) { - return null; - } - - // We don't set the atomic ref to null here upon expiration since we know it is about to be replaced with a fresh instance. - final var stats = cachedStats.get(); - return stats == null || threadPool.relativeTimeInMillis() - stats.timestampMillis > ttlMillis ? null : stats.stats; + void get(BooleanSupplier isCancelled, ActionListener> listener) { + get(threadPool.relativeTimeInMillis(), isCancelled, listener); } - void put(Map stats) { - if (ttlMillis > 0L) { - cachedStats.set(new CachedAllocationStats(stats, threadPool.relativeTimeInMillis())); + @Override + protected void refresh( + Long aLong, + Runnable ensureNotCancelled, + BooleanSupplier supersedeIfStale, + ActionListener> listener + ) { + if (supersedeIfStale.getAsBoolean() == false) { + executorService.execute( + ActionRunnable.supply( + // If caching is disabled the item is only cached long enough to prevent duplicate concurrent requests. + ActionListener.runBefore(listener, this::clearCacheIfDisabled), + () -> allocationStatsService.stats(ensureNotCancelled) + ) + ); } } - } - private static class TaskListenerPair { - private final CancellableTask task; - private final ActionListener listener; - private boolean detectedCancellation; + @Override + protected Long getKey(Long timestampMillis) { + return timestampMillis; + } - TaskListenerPair(CancellableTask task, ActionListener listener) { - this.task = task; - this.listener = listener; - this.detectedCancellation = false; + @Override + protected boolean isFresh(Long currentKey, Long newKey) { + return ttlMillis == 0 || newKey - currentKey <= ttlMillis; } - boolean isCancelled() { - if (detectedCancellation == false && task.isCancelled()) { - detectedCancellation = true; - listener.onFailure(new TaskCancelledException("task cancelled")); + private void clearCacheIfDisabled() { + if (ttlMillis == 0) { + clearCurrentCachedItem(); } - return task.isCancelled(); } } } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java index b37bc92c4729c..b327c322182e1 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.routing.allocation.AllocationStatsService; import org.elasticsearch.cluster.routing.allocation.NodeAllocationStats; @@ -39,7 +38,6 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; -import java.util.Arrays; import java.util.EnumSet; import java.util.List; import java.util.Map; @@ -196,23 +194,23 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { } } - public void testAllTasksCancelledStopsComputationSingleThread() throws InterruptedException { - runAllTasksCancelledStopsComputationTestForNumThreads(1, false); + public void testAllTasksCancelledCacheEnabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), false, true); } - public void testAllTasksCancelledStopsComputationMultipleThreads() throws InterruptedException { - runAllTasksCancelledStopsComputationTestForNumThreads(between(2, 10), false); + public void testAllTasksCancelledCacheDisabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), true, true); } - public void testAllTasksCancelledStopsComputationSingleThreadCacheDisabled() throws InterruptedException { - runAllTasksCancelledStopsComputationTestForNumThreads(1, true); + public void testSomeTasksCancelledCacheEnabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), false, false); } - public void testAllTasksCancelledStopsComputationMultipleThreadsCacheDisabled() throws InterruptedException { - runAllTasksCancelledStopsComputationTestForNumThreads(between(2, 10), true); + public void testSomeTasksCancelledCacheDisabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), true, false); } - private void runAllTasksCancelledStopsComputationTestForNumThreads(final int numThreads, final boolean cacheDisabled) + private void runTestWithCancelledTasks(final int numThreads, final boolean cacheDisabled, final boolean cancelAllTasks) throws InterruptedException { if (cacheDisabled) { disableAllocationStatsCache(); @@ -220,52 +218,6 @@ private void runAllTasksCancelledStopsComputationTestForNumThreads(final int num final var isExecuting = new AtomicBoolean(); final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class); final var tasks = new CancellableTask[numThreads]; - - when(allocationStatsService.stats(ensureNotCancelledCaptor.capture())).thenAnswer(invocation -> { - try { - assertTrue(isExecuting.compareAndSet(false, true)); - Arrays.stream(tasks).forEach(task -> TaskCancelHelper.cancel(task, "cancelled")); - ensureNotCancelledCaptor.getValue().run(); - fail("expected computation to stop when all tasks are cancelled"); - return null; - } finally { - Thread.yield(); - assertTrue(isExecuting.compareAndSet(true, false)); - } - }); - - ESTestCase.startInParallel(numThreads, threadNumber -> { - tasks[threadNumber] = getTask(); - final SubscribableListener listener = SubscribableListener.newForked( - l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l) - ); - safeAwaitFailure(listener); - }); - } - - public void testRunSomeTasksCancelledForSingleThread() throws InterruptedException { - runSomeTasksCancelledForNumThreads(1, false); - } - - public void testRunSomeTasksCancelledForMultipleThreads() throws InterruptedException { - runSomeTasksCancelledForNumThreads(between(2, 10), false); - } - - public void testRunSomeTasksCancelledForSingleThreadCacheDisabled() throws InterruptedException { - runSomeTasksCancelledForNumThreads(1, true); - } - - public void testRunSomeTasksCancelledForMultipleThreadsCacheDisabled() throws InterruptedException { - runSomeTasksCancelledForNumThreads(between(2, 10), true); - } - - private void runSomeTasksCancelledForNumThreads(final int numThreads, final boolean cacheDisabled) throws InterruptedException { - if (cacheDisabled) { - disableAllocationStatsCache(); - } - final var isExecuting = new AtomicBoolean(); - final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class); - final var tasks = new CancellableTask[numThreads]; final var cancellations = new boolean[numThreads]; final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()); @@ -287,17 +239,17 @@ private void runSomeTasksCancelledForNumThreads(final int numThreads, final bool ESTestCase.startInParallel(numThreads, threadNumber -> { tasks[threadNumber] = getTask(); - cancellations[threadNumber] = threadNumber > 0; - final SubscribableListener listener = SubscribableListener.newForked( - l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l) - ); - listener.addListener(ActionListener.wrap(response -> { assertSame(stats, response.getNodeAllocationStats()); }, e -> { + cancellations[threadNumber] = cancelAllTasks || randomBoolean(); + final ActionListener listener = ActionListener.wrap(response -> { + assertSame(stats, response.getNodeAllocationStats()); + }, e -> { if (e instanceof TaskCancelledException) { assertTrue("got an unexpected cancellation exception for thread " + threadNumber, cancellations[threadNumber]); } else { fail(e); } - })); + }); + ActionListener.run(listener, l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l)); }); } @@ -345,7 +297,9 @@ public void testGetStatsWithCachingEnabled() throws Exception { verifyAllocationStatsServiceNumCallsEqualTo(numExpectedAllocationStatsServiceCalls); // Re-enable the cache, only one thread should call the stats service. - setAllocationStatsCacheTTL(TimeValue.timeValueMinutes(5)); + final var newTTL = TimeValue.timeValueMinutes(5); + setAllocationStatsCacheTTL(newTTL); + threadPool.setCurrentTimeInMillis(threadPool.relativeTimeInMillis() + newTTL.getMillis() + 1); resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); From 001a2b73055710858ff7fc26db97dee47359f7d1 Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Mon, 19 May 2025 17:35:28 -0400 Subject: [PATCH 5/7] Add test for clearCurrentCachedItem() --- .../CancellableSingleObjectCacheTests.java | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java b/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java index a0c0bc64b4c8f..8760b0d9faa08 100644 --- a/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java @@ -448,6 +448,26 @@ protected String getKey(String s) { expectThrows(ExecutionException.class, TaskCancelledException.class, cancelledFuture::result); } + public void testClearCurrentCachedItem() throws ExecutionException { + final TestCache testCache = new TestCache(); + + // The first get() calls the refresh function. + final TestFuture future0 = new TestFuture(); + testCache.get("foo", () -> false, future0); + testCache.assertPendingRefreshes(1); + testCache.completeNextRefresh("foo", 1); + assertThat(future0.result(), equalTo(1)); + + testCache.clearCurrentCachedItem(); + + // The second get() with a matching key will execute a refresh since the cached item was cleared. + final TestFuture future1 = new TestFuture(); + testCache.get("foo", () -> false, future1); + testCache.assertPendingRefreshes(1); + testCache.completeNextRefresh("foo", 2); + assertThat(future1.result(), equalTo(2)); + } + private static final ThreadContext testThreadContext = new ThreadContext(Settings.EMPTY); private static class TestCache extends CancellableSingleObjectCache { From 7511df2ec2f646ef94c804949c1fd3f928d5e165 Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Mon, 19 May 2025 17:35:58 -0400 Subject: [PATCH 6/7] Address code review comments --- .../allocation/TransportGetAllocationStatsAction.java | 8 ++++---- .../allocator/DesiredBalanceShardsAllocator.java | 9 ++++++++- .../common/util/CancellableSingleObjectCache.java | 2 +- .../TransportGetAllocationStatsActionTests.java | 3 ++- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java index 42789094f2239..e46762c9e97ab 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java @@ -44,7 +44,7 @@ import java.io.IOException; import java.util.EnumSet; import java.util.Map; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executor; import java.util.function.BooleanSupplier; public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction< @@ -207,13 +207,13 @@ public DiskThresholdSettings getDiskThresholdSettings() { private static class AllocationStatsCache extends CancellableSingleObjectCache> { private volatile long ttlMillis; private final ThreadPool threadPool; - private final ExecutorService executorService; + private final Executor executor; private final AllocationStatsService allocationStatsService; AllocationStatsCache(ThreadPool threadPool, AllocationStatsService allocationStatsService, TimeValue ttl) { super(threadPool.getThreadContext()); this.threadPool = threadPool; - this.executorService = threadPool.executor(ThreadPool.Names.MANAGEMENT); + this.executor = threadPool.executor(ThreadPool.Names.MANAGEMENT); this.allocationStatsService = allocationStatsService; setTTL(ttl); } @@ -235,7 +235,7 @@ protected void refresh( ActionListener> listener ) { if (supersedeIfStale.getAsBoolean() == false) { - executorService.execute( + executor.execute( ActionRunnable.supply( // If caching is disabled the item is only cached long enough to prevent duplicate concurrent requests. ActionListener.runBefore(listener, this::clearCacheIfDisabled), diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java index d70802f0b64f9..515da761d8696 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java @@ -391,6 +391,13 @@ public void resetDesiredBalance() { resetCurrentDesiredBalance = true; } + /** + * Used as the argument for the {@code ensureNotCancelled} {@code Runnable} when calling the + * {@code nodeAllocationStatsAndWeightsCalculator} since there is no cancellation mechanism when called from + * {@code updateDesireBalanceMetrics()}. + */ + private static final Runnable NEVER_CANCELLED = () -> {}; + private void updateDesireBalanceMetrics( DesiredBalance desiredBalance, RoutingAllocation routingAllocation, @@ -400,7 +407,7 @@ private void updateDesireBalanceMetrics( routingAllocation.metadata(), routingAllocation.routingNodes(), routingAllocation.clusterInfo(), - () -> {}, + NEVER_CANCELLED, desiredBalance ); Map filteredNodeAllocationStatsAndWeights = diff --git a/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java b/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java index b89587192a061..2e33dc47f0d60 100644 --- a/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java +++ b/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java @@ -103,7 +103,7 @@ protected boolean isFresh(Key currentKey, Key newKey) { /** * Sets the currently cached item reference to {@code null}, which will result in a {@code refresh()} on the next {@code get()} call. */ - protected void clearCurrentCachedItem() { + protected final void clearCurrentCachedItem() { this.currentCachedItemRef.set(null); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java index b327c322182e1..6909e916dd481 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java @@ -51,6 +51,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -306,7 +307,7 @@ public void testGetStatsWithCachingEnabled() throws Exception { } private void verifyAllocationStatsServiceNumCallsEqualTo(int numCalls) { - verify(allocationStatsService, times(numCalls)).stats(argThat(r -> true)); + verify(allocationStatsService, times(numCalls)).stats(any()); } private static TransportGetAllocationStatsAction.Request getRequest() { From 2e7fa5052e20d4fbf4dba8ed143c623d116d48e2 Mon Sep 17 00:00:00 2001 From: Jeremy Dahlgren Date: Mon, 19 May 2025 17:43:13 -0400 Subject: [PATCH 7/7] Use ArgumentMatchers.any() --- .../allocation/TransportGetAllocationStatsActionTests.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java index 6909e916dd481..133cf5d648611 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java @@ -52,7 +52,6 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -133,7 +132,7 @@ public void testReturnsOnlyRequestedStats() throws Exception { )) { var request = new TransportGetAllocationStatsAction.Request(TimeValue.ONE_MINUTE, TaskId.EMPTY_TASK_ID, metrics); - when(allocationStatsService.stats(argThat(r -> true))).thenReturn( + when(allocationStatsService.stats(any())).thenReturn( Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()) ); @@ -161,7 +160,7 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { disableAllocationStatsCache(); final var requestCounter = new AtomicInteger(); final var isExecuting = new AtomicBoolean(); - when(allocationStatsService.stats(argThat(r -> true))).thenAnswer(invocation -> { + when(allocationStatsService.stats(any())).thenAnswer(invocation -> { try { assertTrue(isExecuting.compareAndSet(false, true)); assertThat(Thread.currentThread().getName(), containsString("[management]")); @@ -262,7 +261,7 @@ public void testGetStatsWithCachingEnabled() throws Exception { final Runnable resetExpectedAllocationStats = () -> { final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()); allocationStats.set(stats); - when(allocationStatsService.stats(argThat(r -> true))).thenReturn(stats); + when(allocationStatsService.stats(any())).thenReturn(stats); }; final CheckedConsumer, Exception> threadTask = l -> {