diff --git a/docs/changelog/127371.yaml b/docs/changelog/127371.yaml new file mode 100644 index 0000000000000..10f5f17243193 --- /dev/null +++ b/docs/changelog/127371.yaml @@ -0,0 +1,6 @@ +pr: 127371 +summary: Add cancellation support in `TransportGetAllocationStatsAction` +area: Allocation +type: feature +issues: + - 123248 diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java index eecbb3525bda9..e46762c9e97ab 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java @@ -15,7 +15,6 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; -import org.elasticsearch.action.SingleResultDeduplicator; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.SubscribableListener; @@ -31,10 +30,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.util.CancellableSingleObjectCache; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; @@ -43,7 +44,8 @@ import java.io.IOException; import java.util.EnumSet; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.Executor; +import java.util.function.BooleanSupplier; public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction< TransportGetAllocationStatsAction.Request, @@ -62,7 +64,6 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc ); private final AllocationStatsCache allocationStatsCache; - private final SingleResultDeduplicator> allocationStatsSupplier; private final DiskThresholdSettings diskThresholdSettings; @Inject @@ -85,21 +86,7 @@ public TransportGetAllocationStatsAction( // very cheaply. EsExecutors.DIRECT_EXECUTOR_SERVICE ); - final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT); - this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL); - this.allocationStatsSupplier = new SingleResultDeduplicator<>(threadPool.getThreadContext(), l -> { - final var cachedStats = allocationStatsCache.get(); - if (cachedStats != null) { - l.onResponse(cachedStats); - return; - } - - managementExecutor.execute(ActionRunnable.supply(l, () -> { - final var stats = allocationStatsService.stats(); - allocationStatsCache.put(stats); - return stats; - })); - }); + this.allocationStatsCache = new AllocationStatsCache(threadPool, allocationStatsService, DEFAULT_CACHE_TTL); this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings()); clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL); } @@ -118,8 +105,11 @@ protected void doExecute(Task task, Request request, ActionListener li protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { // NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool + assert task instanceof CancellableTask; + final var cancellableTask = (CancellableTask) task; + final SubscribableListener> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS) - ? SubscribableListener.newForked(allocationStatsSupplier::execute) + ? SubscribableListener.newForked(l -> allocationStatsCache.get(cancellableTask::isCancelled, l)) : SubscribableListener.newSucceeded(Map.of()); allocationStatsStep.andThenApply( @@ -167,6 +157,11 @@ public EnumSet metrics() { public ActionRequestValidationException validate() { return null; } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "", parentTaskId, headers); + } } public static class Response extends ActionResponse { @@ -209,39 +204,60 @@ public DiskThresholdSettings getDiskThresholdSettings() { } } - private record CachedAllocationStats(Map stats, long timestampMillis) {} - - private static class AllocationStatsCache { + private static class AllocationStatsCache extends CancellableSingleObjectCache> { private volatile long ttlMillis; private final ThreadPool threadPool; - private final AtomicReference cachedStats; + private final Executor executor; + private final AllocationStatsService allocationStatsService; - AllocationStatsCache(ThreadPool threadPool, TimeValue ttl) { + AllocationStatsCache(ThreadPool threadPool, AllocationStatsService allocationStatsService, TimeValue ttl) { + super(threadPool.getThreadContext()); this.threadPool = threadPool; - this.cachedStats = new AtomicReference<>(); + this.executor = threadPool.executor(ThreadPool.Names.MANAGEMENT); + this.allocationStatsService = allocationStatsService; setTTL(ttl); } void setTTL(TimeValue ttl) { ttlMillis = ttl.millis(); - if (ttlMillis == 0L) { - cachedStats.set(null); - } + clearCacheIfDisabled(); } - Map get() { - if (ttlMillis == 0L) { - return null; + void get(BooleanSupplier isCancelled, ActionListener> listener) { + get(threadPool.relativeTimeInMillis(), isCancelled, listener); + } + + @Override + protected void refresh( + Long aLong, + Runnable ensureNotCancelled, + BooleanSupplier supersedeIfStale, + ActionListener> listener + ) { + if (supersedeIfStale.getAsBoolean() == false) { + executor.execute( + ActionRunnable.supply( + // If caching is disabled the item is only cached long enough to prevent duplicate concurrent requests. + ActionListener.runBefore(listener, this::clearCacheIfDisabled), + () -> allocationStatsService.stats(ensureNotCancelled) + ) + ); } + } - // We don't set the atomic ref to null here upon expiration since we know it is about to be replaced with a fresh instance. - final var stats = cachedStats.get(); - return stats == null || threadPool.relativeTimeInMillis() - stats.timestampMillis > ttlMillis ? null : stats.stats; + @Override + protected Long getKey(Long timestampMillis) { + return timestampMillis; + } + + @Override + protected boolean isFresh(Long currentKey, Long newKey) { + return ttlMillis == 0 || newKey - currentKey <= ttlMillis; } - void put(Map stats) { - if (ttlMillis > 0L) { - cachedStats.set(new CachedAllocationStats(stats, threadPool.relativeTimeInMillis())); + private void clearCacheIfDisabled() { + if (ttlMillis == 0) { + clearCurrentCachedItem(); } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java index 926a6926c9aea..f31ddd36a2e31 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsService.java @@ -47,6 +47,14 @@ public AllocationStatsService( * Returns a map of node IDs to node allocation stats. */ public Map stats() { + return stats(() -> {}); + } + + /** + * Returns a map of node IDs to node allocation stats, promising to execute the provided {@link Runnable} during the computation to + * test for cancellation. + */ + public Map stats(Runnable ensureNotCancelled) { assert Transports.assertNotTransportThread("too expensive for a transport worker"); var clusterState = clusterService.state(); @@ -54,6 +62,7 @@ public Map stats() { clusterState.metadata(), clusterState.getRoutingNodes(), clusterInfoService.getClusterInfo(), + ensureNotCancelled, desiredBalanceSupplier.get() ); return nodesStatsAndWeights.entrySet() diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java index 21e006f76b1d1..c92a65543d0ff 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/NodeAllocationStatsAndWeightsCalculator.java @@ -58,6 +58,7 @@ public Map nodesAllocationStatsAndWeights( Metadata metadata, RoutingNodes routingNodes, ClusterInfo clusterInfo, + Runnable ensureNotCancelled, @Nullable DesiredBalance desiredBalance ) { if (metadata.hasAnyIndices()) { @@ -78,6 +79,7 @@ public Map nodesAllocationStatsAndWeights( long forecastedDiskUsage = 0; long currentDiskUsage = 0; for (ShardRouting shardRouting : node) { + ensureNotCancelled.run(); if (shardRouting.relocating()) { // Skip the shard if it is moving off this node. The node running recovery will count it. continue; diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java index e8d8d509282ab..515da761d8696 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java @@ -391,6 +391,13 @@ public void resetDesiredBalance() { resetCurrentDesiredBalance = true; } + /** + * Used as the argument for the {@code ensureNotCancelled} {@code Runnable} when calling the + * {@code nodeAllocationStatsAndWeightsCalculator} since there is no cancellation mechanism when called from + * {@code updateDesireBalanceMetrics()}. + */ + private static final Runnable NEVER_CANCELLED = () -> {}; + private void updateDesireBalanceMetrics( DesiredBalance desiredBalance, RoutingAllocation routingAllocation, @@ -400,6 +407,7 @@ private void updateDesireBalanceMetrics( routingAllocation.metadata(), routingAllocation.routingNodes(), routingAllocation.clusterInfo(), + NEVER_CANCELLED, desiredBalance ); Map filteredNodeAllocationStatsAndWeights = diff --git a/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java b/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java index fa8ec26bbad2c..2e33dc47f0d60 100644 --- a/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java +++ b/server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java @@ -100,6 +100,13 @@ protected boolean isFresh(Key currentKey, Key newKey) { return currentKey.equals(newKey); } + /** + * Sets the currently cached item reference to {@code null}, which will result in a {@code refresh()} on the next {@code get()} call. + */ + protected final void clearCurrentCachedItem() { + this.currentCachedItemRef.set(null); + } + /** * Start a retrieval for the value associated with the given {@code input}, and pass it to the given {@code listener}. *

@@ -110,7 +117,8 @@ protected boolean isFresh(Key currentKey, Key newKey) { * * @param input The input to compute the desired value, converted to a {@link Key} to determine if the value that's currently * cached or pending is fresh enough. - * @param isCancelled Returns {@code true} if the listener no longer requires the value being computed. + * @param isCancelled Returns {@code true} if the listener no longer requires the value being computed. The listener is expected to be + * completed as soon as possible when cancellation is detected. * @param listener The listener to notify when the desired value becomes available. */ public final void get(Input input, BooleanSupplier isCancelled, ActionListener listener) { @@ -230,11 +238,15 @@ boolean addListener(ActionListener listener, BooleanSupplier isCancelled) ActionListener.completeWith(listener, future::actionResult); } else { // Refresh is still pending; it's not cancelled because there are still references. - future.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)); + final var cancellableListener = ActionListener.notifyOnce( + ContextPreservingActionListener.wrapPreservingContext(listener, threadContext) + ); + future.addListener(cancellableListener); final AtomicBoolean released = new AtomicBoolean(); cancellationChecks.add(() -> { if (released.get() == false && isCancelled.getAsBoolean() && released.compareAndSet(false, true)) { decRef(); + cancellableListener.onFailure(new TaskCancelledException("task cancelled")); } }); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java index d60ac5ca47f6d..133cf5d648611 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java @@ -23,7 +23,9 @@ import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.TimeValue; import org.elasticsearch.node.Node; -import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.TaskCancelHelper; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.telemetry.metric.MeterRegistry; import org.elasticsearch.test.ClusterServiceUtils; @@ -34,6 +36,7 @@ import org.elasticsearch.transport.TransportService; import org.junit.After; import org.junit.Before; +import org.mockito.ArgumentCaptor; import java.util.EnumSet; import java.util.List; @@ -48,6 +51,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.not; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -126,26 +130,22 @@ public void testReturnsOnlyRequestedStats() throws Exception { EnumSet.allOf(Metric.class), EnumSet.copyOf(randomSubsetOf(between(1, Metric.values().length), EnumSet.allOf(Metric.class))) )) { - var request = new TransportGetAllocationStatsAction.Request( - TimeValue.ONE_MINUTE, - new TaskId(randomIdentifier(), randomNonNegativeLong()), - metrics - ); + var request = new TransportGetAllocationStatsAction.Request(TimeValue.ONE_MINUTE, TaskId.EMPTY_TASK_ID, metrics); - when(allocationStatsService.stats()).thenReturn( + when(allocationStatsService.stats(any())).thenReturn( Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()) ); var future = new PlainActionFuture(); - action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future); + action.masterOperation(getTask(), request, ClusterState.EMPTY_STATE, future); var response = future.get(); if (metrics.contains(Metric.ALLOCATIONS)) { assertThat(response.getNodeAllocationStats(), not(anEmptyMap())); - verify(allocationStatsService, times(++expectedNumberOfStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++expectedNumberOfStatsServiceCalls); } else { assertThat(response.getNodeAllocationStats(), anEmptyMap()); - verify(allocationStatsService, times(expectedNumberOfStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(expectedNumberOfStatsServiceCalls); } if (metrics.contains(Metric.FS)) { @@ -160,7 +160,7 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { disableAllocationStatsCache(); final var requestCounter = new AtomicInteger(); final var isExecuting = new AtomicBoolean(); - when(allocationStatsService.stats()).thenAnswer(invocation -> { + when(allocationStatsService.stats(any())).thenAnswer(invocation -> { try { assertTrue(isExecuting.compareAndSet(false, true)); assertThat(Thread.currentThread().getName(), containsString("[management]")); @@ -180,16 +180,7 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { final var minRequestIndex = requestCounter.get(); final TransportGetAllocationStatsAction.Response response = safeAwait( - l -> action.masterOperation( - mock(Task.class), - new TransportGetAllocationStatsAction.Request( - TEST_REQUEST_TIMEOUT, - TaskId.EMPTY_TASK_ID, - EnumSet.of(Metric.ALLOCATIONS) - ), - ClusterState.EMPTY_STATE, - l - ) + l -> action.masterOperation(getTask(), getRequest(), ClusterState.EMPTY_STATE, l) ); final var requestIndex = Integer.valueOf(response.getNodeAllocationStats().keySet().iterator().next()); @@ -203,6 +194,65 @@ public void testDeduplicatesStatsComputations() throws InterruptedException { } } + public void testAllTasksCancelledCacheEnabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), false, true); + } + + public void testAllTasksCancelledCacheDisabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), true, true); + } + + public void testSomeTasksCancelledCacheEnabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), false, false); + } + + public void testSomeTasksCancelledCacheDisabled() throws InterruptedException { + runTestWithCancelledTasks(between(2, 10), true, false); + } + + private void runTestWithCancelledTasks(final int numThreads, final boolean cacheDisabled, final boolean cancelAllTasks) + throws InterruptedException { + if (cacheDisabled) { + disableAllocationStatsCache(); + } + final var isExecuting = new AtomicBoolean(); + final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class); + final var tasks = new CancellableTask[numThreads]; + final var cancellations = new boolean[numThreads]; + final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()); + + when(allocationStatsService.stats(ensureNotCancelledCaptor.capture())).thenAnswer(invocation -> { + try { + assertTrue(isExecuting.compareAndSet(false, true)); + for (int i = 0; i < numThreads; ++i) { + if (cancellations[i]) { + TaskCancelHelper.cancel(tasks[i], "cancelled"); + } + } + ensureNotCancelledCaptor.getValue().run(); + return stats; + } finally { + Thread.yield(); + assertTrue(isExecuting.compareAndSet(true, false)); + } + }); + + ESTestCase.startInParallel(numThreads, threadNumber -> { + tasks[threadNumber] = getTask(); + cancellations[threadNumber] = cancelAllTasks || randomBoolean(); + final ActionListener listener = ActionListener.wrap(response -> { + assertSame(stats, response.getNodeAllocationStats()); + }, e -> { + if (e instanceof TaskCancelledException) { + assertTrue("got an unexpected cancellation exception for thread " + threadNumber, cancellations[threadNumber]); + } else { + fail(e); + } + }); + ActionListener.run(listener, l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l)); + }); + } + public void testGetStatsWithCachingEnabled() throws Exception { final AtomicReference> allocationStats = new AtomicReference<>(); @@ -211,17 +261,11 @@ public void testGetStatsWithCachingEnabled() throws Exception { final Runnable resetExpectedAllocationStats = () -> { final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats()); allocationStats.set(stats); - when(allocationStatsService.stats()).thenReturn(stats); + when(allocationStatsService.stats(any())).thenReturn(stats); }; final CheckedConsumer, Exception> threadTask = l -> { - final var request = new TransportGetAllocationStatsAction.Request( - TEST_REQUEST_TIMEOUT, - new TaskId(randomIdentifier(), randomNonNegativeLong()), - EnumSet.of(Metric.ALLOCATIONS) - ); - - action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, l.map(response -> { + action.masterOperation(getTask(), getRequest(), ClusterState.EMPTY_STATE, l.map(response -> { assertSame("Expected the cached allocation stats to be returned", response.getNodeAllocationStats(), allocationStats.get()); return null; })); @@ -230,12 +274,12 @@ public void testGetStatsWithCachingEnabled() throws Exception { // Initial cache miss, all threads should get the same value. resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); // Advance the clock to a time less than or equal to the TTL and verify we still get the cached stats. threadPool.setCurrentTimeInMillis(startTimeMillis + between(0, (int) allocationStatsCacheTTL.millis())); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(numExpectedAllocationStatsServiceCalls); // Force the cached stats to expire. threadPool.setCurrentTimeInMillis(startTimeMillis + allocationStatsCacheTTL.getMillis() + 1); @@ -243,20 +287,34 @@ public void testGetStatsWithCachingEnabled() throws Exception { // Expect a single call to the stats service on the cache miss. resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); // Update the TTL setting to disable the cache, we expect a service call each time. setAllocationStatsCacheTTL(TimeValue.ZERO); safeAwait(threadTask); safeAwait(threadTask); numExpectedAllocationStatsServiceCalls += 2; - verify(allocationStatsService, times(numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(numExpectedAllocationStatsServiceCalls); // Re-enable the cache, only one thread should call the stats service. - setAllocationStatsCacheTTL(TimeValue.timeValueMinutes(5)); + final var newTTL = TimeValue.timeValueMinutes(5); + setAllocationStatsCacheTTL(newTTL); + threadPool.setCurrentTimeInMillis(threadPool.relativeTimeInMillis() + newTTL.getMillis() + 1); resetExpectedAllocationStats.run(); ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask)); - verify(allocationStatsService, times(++numExpectedAllocationStatsServiceCalls)).stats(); + verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls); + } + + private void verifyAllocationStatsServiceNumCallsEqualTo(int numCalls) { + verify(allocationStatsService, times(numCalls)).stats(any()); + } + + private static TransportGetAllocationStatsAction.Request getRequest() { + return new TransportGetAllocationStatsAction.Request(TEST_REQUEST_TIMEOUT, TaskId.EMPTY_TASK_ID, EnumSet.of(Metric.ALLOCATIONS)); + } + + private static CancellableTask getTask() { + return new CancellableTask(randomLong(), "type", "action", "desc", null, Map.of()); } private static class ControlledRelativeTimeThreadPool extends ThreadPool { diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java index cf2653bc6c559..4a07b837b08af 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationStatsServiceTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.test.ClusterServiceUtils; @@ -92,7 +93,7 @@ public void testShardStats() { ) ); assertThat( - service.stats(), + service.stats(() -> {}), allOf( aMapWithSize(1), hasEntry( @@ -101,6 +102,9 @@ public void testShardStats() { ) ) ); + + // Verify that the ensureNotCancelled Runnable is tested during execution. + assertThrows(TaskCancelledException.class, () -> service.stats(() -> { throw new TaskCancelledException("cancelled"); })); } } diff --git a/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java b/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java index ecb74009682d3..8760b0d9faa08 100644 --- a/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/CancellableSingleObjectCacheTests.java @@ -101,8 +101,35 @@ public void testListenerCompletedByRefreshEvenIfDiscarded() throws ExecutionExce testCache.completeNextRefresh("foo", 1); assertThat(future2.result(), equalTo(1)); - // ... and the original listener is also completed successfully - assertThat(future1.result(), sameInstance(future2.result())); + // We expect the first listener to have been completed with a cancellation exception when detected in the ensureNotCancelled() call. + assertTrue(future1.isDone()); + expectThrows(ExecutionException.class, TaskCancelledException.class, future1::result); + } + + public void testBothListenersReceiveTaskCancelledExceptionWhenBothSupersededAndNewTasksAreCancelled() { + final TestCache testCache = new TestCache(); + + // This computation is superseded and then cancelled. + final AtomicBoolean isCancelled = new AtomicBoolean(); + final TestFuture future1 = new TestFuture(); + testCache.get("foo", isCancelled::get, future1); + testCache.assertPendingRefreshes(1); + + // A second get() call that supersedes the original refresh and starts another one, but will be cancelled as well. + final TestFuture future2 = new TestFuture(); + testCache.get("bar", isCancelled::get, future2); + testCache.assertPendingRefreshes(2); + + testCache.assertNextRefreshCancelled(); + assertFalse(future1.isDone()); + testCache.assertPendingRefreshes(1); + assertFalse(future2.isDone()); + + isCancelled.set(true); + // This next refresh should also fail with a cancellation exception. + testCache.completeNextRefresh("bar", 1); + expectThrows(ExecutionException.class, TaskCancelledException.class, future1::result); + expectThrows(ExecutionException.class, TaskCancelledException.class, future2::result); } public void testListenerCompletedWithCancellationExceptionIfRefreshCancelled() throws ExecutionException { @@ -421,6 +448,26 @@ protected String getKey(String s) { expectThrows(ExecutionException.class, TaskCancelledException.class, cancelledFuture::result); } + public void testClearCurrentCachedItem() throws ExecutionException { + final TestCache testCache = new TestCache(); + + // The first get() calls the refresh function. + final TestFuture future0 = new TestFuture(); + testCache.get("foo", () -> false, future0); + testCache.assertPendingRefreshes(1); + testCache.completeNextRefresh("foo", 1); + assertThat(future0.result(), equalTo(1)); + + testCache.clearCurrentCachedItem(); + + // The second get() with a matching key will execute a refresh since the cached item was cleared. + final TestFuture future1 = new TestFuture(); + testCache.get("foo", () -> false, future1); + testCache.assertPendingRefreshes(1); + testCache.completeNextRefresh("foo", 2); + assertThat(future1.result(), equalTo(2)); + } + private static final ThreadContext testThreadContext = new ThreadContext(Settings.EMPTY); private static class TestCache extends CancellableSingleObjectCache { diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java index e6a3f7664bd28..8a49db652374e 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java @@ -453,8 +453,10 @@ public Map nodesAllocationStatsAndWeights( Metadata metadata, RoutingNodes routingNodes, ClusterInfo clusterInfo, + Runnable ensureNotCancelled, @Nullable DesiredBalance desiredBalance ) { + ensureNotCancelled.run(); return Map.of(); } };