Skip to content

Commit 91df717

Browse files
Refactor to use CancellableSingleObjectCache
1 parent b423b7b commit 91df717

File tree

2 files changed: 64 additions (+64) and 178 deletions (-178).

2 files changed: 64 additions (+64) and 178 deletions (-178).

server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java

Lines changed: 46 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.action.ActionType;
1818
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric;
1919
import org.elasticsearch.action.support.ActionFilters;
20-
import org.elasticsearch.action.support.ContextPreservingActionListener;
2120
import org.elasticsearch.action.support.SubscribableListener;
2221
import org.elasticsearch.action.support.master.MasterNodeReadRequest;
2322
import org.elasticsearch.action.support.master.TransportMasterNodeReadAction;
@@ -31,24 +30,22 @@
3130
import org.elasticsearch.common.io.stream.StreamInput;
3231
import org.elasticsearch.common.io.stream.StreamOutput;
3332
import org.elasticsearch.common.settings.Setting;
33+
import org.elasticsearch.common.util.CancellableSingleObjectCache;
3434
import org.elasticsearch.common.util.concurrent.EsExecutors;
3535
import org.elasticsearch.core.Nullable;
3636
import org.elasticsearch.core.TimeValue;
3737
import org.elasticsearch.injection.guice.Inject;
3838
import org.elasticsearch.tasks.CancellableTask;
3939
import org.elasticsearch.tasks.Task;
40-
import org.elasticsearch.tasks.TaskCancelledException;
4140
import org.elasticsearch.tasks.TaskId;
4241
import org.elasticsearch.threadpool.ThreadPool;
4342
import org.elasticsearch.transport.TransportService;
4443

4544
import java.io.IOException;
46-
import java.util.ArrayList;
4745
import java.util.EnumSet;
48-
import java.util.List;
4946
import java.util.Map;
50-
import java.util.concurrent.atomic.AtomicReference;
51-
import java.util.function.Consumer;
47+
import java.util.concurrent.ExecutorService;
48+
import java.util.function.BooleanSupplier;
5249

5350
public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAction<
5451
TransportGetAllocationStatsAction.Request,
@@ -67,10 +64,7 @@ public class TransportGetAllocationStatsAction extends TransportMasterNodeReadAc
6764
);
6865

6966
private final AllocationStatsCache allocationStatsCache;
70-
private final Consumer<ActionListener<Map<String, NodeAllocationStats>>> allocationStatsSupplier;
7167
private final DiskThresholdSettings diskThresholdSettings;
72-
private SubscribableListener<Response> waitingListeners;
73-
private List<TaskListenerPair> tasksList;
7468

7569
@Inject
7670
public TransportGetAllocationStatsAction(
@@ -92,21 +86,7 @@ public TransportGetAllocationStatsAction(
9286
// very cheaply.
9387
EsExecutors.DIRECT_EXECUTOR_SERVICE
9488
);
95-
final var managementExecutor = threadPool.executor(ThreadPool.Names.MANAGEMENT);
96-
this.allocationStatsCache = new AllocationStatsCache(threadPool, DEFAULT_CACHE_TTL);
97-
this.allocationStatsSupplier = l -> {
98-
final var cachedStats = allocationStatsCache.get();
99-
if (cachedStats != null) {
100-
l.onResponse(cachedStats);
101-
return;
102-
}
103-
104-
managementExecutor.execute(ActionRunnable.supply(l, () -> {
105-
final var stats = allocationStatsService.stats(this::ensureNotCancelled);
106-
allocationStatsCache.put(stats);
107-
return stats;
108-
}));
109-
};
89+
this.allocationStatsCache = new AllocationStatsCache(threadPool, allocationStatsService, DEFAULT_CACHE_TTL);
11090
this.diskThresholdSettings = new DiskThresholdSettings(clusterService.getSettings(), clusterService.getClusterSettings());
11191
clusterService.getClusterSettings().initializeAndWatch(CACHE_TTL_SETTING, this.allocationStatsCache::setTTL);
11292
}
@@ -125,65 +105,16 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
125105
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
126106
// NB we are still on a transport thread here - if adding more functionality here make sure to fork to a different pool
127107

128-
if (request.metrics().contains(Metric.ALLOCATIONS) == false) {
129-
listener.onResponse(statsToResponse(Map.of(), request));
130-
return;
131-
}
132-
// Perform a cheap check for the cached stats up front.
133-
final var cachedStats = allocationStatsCache.get();
134-
if (cachedStats != null) {
135-
listener.onResponse(statsToResponse(cachedStats, request));
136-
return;
137-
}
138-
139108
assert task instanceof CancellableTask;
140-
final var wrappedListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
141-
final var taskListenerPair = new TaskListenerPair((CancellableTask) task, wrappedListener);
142-
143-
synchronized (this) {
144-
if (waitingListeners != null) {
145-
tasksList.add(taskListenerPair);
146-
waitingListeners.addListener(wrappedListener);
147-
return;
148-
}
149-
150-
tasksList = new ArrayList<>();
151-
waitingListeners = new SubscribableListener<>();
152-
tasksList.add(taskListenerPair);
153-
waitingListeners.addListener(ActionListener.runBefore(wrappedListener, () -> {
154-
synchronized (this) {
155-
waitingListeners = null;
156-
tasksList = null;
157-
}
158-
}));
159-
}
109+
final var cancellableTask = (CancellableTask) task;
160110

161-
SubscribableListener.newForked(allocationStatsSupplier::accept)
162-
.andThenApply(stats -> statsToResponse(stats, request))
163-
.addListener(waitingListeners);
164-
}
165-
166-
private Response statsToResponse(Map<String, NodeAllocationStats> stats, Request request) {
167-
return new Response(stats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null);
168-
}
111+
final SubscribableListener<Map<String, NodeAllocationStats>> allocationStatsStep = request.metrics().contains(Metric.ALLOCATIONS)
112+
? SubscribableListener.newForked(l -> allocationStatsCache.get(cancellableTask::isCancelled, l))
113+
: SubscribableListener.newSucceeded(Map.of());
169114

170-
private void ensureNotCancelled() {
171-
final int count;
172-
synchronized (this) {
173-
count = tasksList.size();
174-
}
175-
boolean allTasksCancelled = true;
176-
// Check each task to give each task a chance to invoke their listener (once) when cancelled.
177-
for (int i = 0; i < count; ++i) {
178-
final TaskListenerPair taskPair;
179-
synchronized (this) {
180-
taskPair = tasksList.get(i);
181-
}
182-
allTasksCancelled &= taskPair.isCancelled();
183-
}
184-
if (allTasksCancelled) {
185-
throw new TaskCancelledException("task cancelled");
186-
}
115+
allocationStatsStep.andThenApply(
116+
allocationStats -> new Response(allocationStats, request.metrics().contains(Metric.FS) ? diskThresholdSettings : null)
117+
).addListener(listener);
187118
}
188119

189120
@Override
@@ -273,60 +204,61 @@ public DiskThresholdSettings getDiskThresholdSettings() {
273204
}
274205
}
275206

276-
private record CachedAllocationStats(Map<String, NodeAllocationStats> stats, long timestampMillis) {}
277-
278-
private static class AllocationStatsCache {
207+
private static class AllocationStatsCache extends CancellableSingleObjectCache<Long, Long, Map<String, NodeAllocationStats>> {
279208
private volatile long ttlMillis;
280209
private final ThreadPool threadPool;
281-
private final AtomicReference<CachedAllocationStats> cachedStats;
210+
private final ExecutorService executorService;
211+
private final AllocationStatsService allocationStatsService;
282212

283-
AllocationStatsCache(ThreadPool threadPool, TimeValue ttl) {
213+
AllocationStatsCache(ThreadPool threadPool, AllocationStatsService allocationStatsService, TimeValue ttl) {
214+
super(threadPool.getThreadContext());
284215
this.threadPool = threadPool;
285-
this.cachedStats = new AtomicReference<>();
216+
this.executorService = threadPool.executor(ThreadPool.Names.MANAGEMENT);
217+
this.allocationStatsService = allocationStatsService;
286218
setTTL(ttl);
287219
}
288220

289221
void setTTL(TimeValue ttl) {
290222
ttlMillis = ttl.millis();
291-
if (ttlMillis == 0L) {
292-
cachedStats.set(null);
293-
}
223+
clearCacheIfDisabled();
294224
}
295225

296-
Map<String, NodeAllocationStats> get() {
297-
if (ttlMillis == 0L) {
298-
return null;
299-
}
300-
301-
// We don't set the atomic ref to null here upon expiration since we know it is about to be replaced with a fresh instance.
302-
final var stats = cachedStats.get();
303-
return stats == null || threadPool.relativeTimeInMillis() - stats.timestampMillis > ttlMillis ? null : stats.stats;
226+
void get(BooleanSupplier isCancelled, ActionListener<Map<String, NodeAllocationStats>> listener) {
227+
get(threadPool.relativeTimeInMillis(), isCancelled, listener);
304228
}
305229

306-
void put(Map<String, NodeAllocationStats> stats) {
307-
if (ttlMillis > 0L) {
308-
cachedStats.set(new CachedAllocationStats(stats, threadPool.relativeTimeInMillis()));
230+
@Override
231+
protected void refresh(
232+
Long aLong,
233+
Runnable ensureNotCancelled,
234+
BooleanSupplier supersedeIfStale,
235+
ActionListener<Map<String, NodeAllocationStats>> listener
236+
) {
237+
if (supersedeIfStale.getAsBoolean() == false) {
238+
executorService.execute(
239+
ActionRunnable.supply(
240+
// If caching is disabled the item is only cached long enough to prevent duplicate concurrent requests.
241+
ActionListener.runBefore(listener, this::clearCacheIfDisabled),
242+
() -> allocationStatsService.stats(ensureNotCancelled)
243+
)
244+
);
309245
}
310246
}
311-
}
312247

313-
private static class TaskListenerPair {
314-
private final CancellableTask task;
315-
private final ActionListener<Response> listener;
316-
private boolean detectedCancellation;
248+
@Override
249+
protected Long getKey(Long timestampMillis) {
250+
return timestampMillis;
251+
}
317252

318-
TaskListenerPair(CancellableTask task, ActionListener<Response> listener) {
319-
this.task = task;
320-
this.listener = listener;
321-
this.detectedCancellation = false;
253+
@Override
254+
protected boolean isFresh(Long currentKey, Long newKey) {
255+
return ttlMillis == 0 || newKey - currentKey <= ttlMillis;
322256
}
323257

324-
boolean isCancelled() {
325-
if (detectedCancellation == false && task.isCancelled()) {
326-
detectedCancellation = true;
327-
listener.onFailure(new TaskCancelledException("task cancelled"));
258+
private void clearCacheIfDisabled() {
259+
if (ttlMillis == 0) {
260+
clearCurrentCachedItem();
328261
}
329-
return task.isCancelled();
330262
}
331263
}
332264
}

server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsActionTests.java

Lines changed: 18 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters.Metric;
1414
import org.elasticsearch.action.support.ActionFilters;
1515
import org.elasticsearch.action.support.PlainActionFuture;
16-
import org.elasticsearch.action.support.SubscribableListener;
1716
import org.elasticsearch.cluster.ClusterState;
1817
import org.elasticsearch.cluster.routing.allocation.AllocationStatsService;
1918
import org.elasticsearch.cluster.routing.allocation.NodeAllocationStats;
@@ -39,7 +38,6 @@
3938
import org.junit.Before;
4039
import org.mockito.ArgumentCaptor;
4140

42-
import java.util.Arrays;
4341
import java.util.EnumSet;
4442
import java.util.List;
4543
import java.util.Map;
@@ -196,76 +194,30 @@ public void testDeduplicatesStatsComputations() throws InterruptedException {
196194
}
197195
}
198196

199-
public void testAllTasksCancelledStopsComputationSingleThread() throws InterruptedException {
200-
runAllTasksCancelledStopsComputationTestForNumThreads(1, false);
197+
public void testAllTasksCancelledCacheEnabled() throws InterruptedException {
198+
runTestWithCancelledTasks(between(2, 10), false, true);
201199
}
202200

203-
public void testAllTasksCancelledStopsComputationMultipleThreads() throws InterruptedException {
204-
runAllTasksCancelledStopsComputationTestForNumThreads(between(2, 10), false);
201+
public void testAllTasksCancelledCacheDisabled() throws InterruptedException {
202+
runTestWithCancelledTasks(between(2, 10), true, true);
205203
}
206204

207-
public void testAllTasksCancelledStopsComputationSingleThreadCacheDisabled() throws InterruptedException {
208-
runAllTasksCancelledStopsComputationTestForNumThreads(1, true);
205+
public void testSomeTasksCancelledCacheEnabled() throws InterruptedException {
206+
runTestWithCancelledTasks(between(2, 10), false, false);
209207
}
210208

211-
public void testAllTasksCancelledStopsComputationMultipleThreadsCacheDisabled() throws InterruptedException {
212-
runAllTasksCancelledStopsComputationTestForNumThreads(between(2, 10), true);
209+
public void testSomeTasksCancelledCacheDisabled() throws InterruptedException {
210+
runTestWithCancelledTasks(between(2, 10), true, false);
213211
}
214212

215-
private void runAllTasksCancelledStopsComputationTestForNumThreads(final int numThreads, final boolean cacheDisabled)
213+
private void runTestWithCancelledTasks(final int numThreads, final boolean cacheDisabled, final boolean cancelAllTasks)
216214
throws InterruptedException {
217215
if (cacheDisabled) {
218216
disableAllocationStatsCache();
219217
}
220218
final var isExecuting = new AtomicBoolean();
221219
final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class);
222220
final var tasks = new CancellableTask[numThreads];
223-
224-
when(allocationStatsService.stats(ensureNotCancelledCaptor.capture())).thenAnswer(invocation -> {
225-
try {
226-
assertTrue(isExecuting.compareAndSet(false, true));
227-
Arrays.stream(tasks).forEach(task -> TaskCancelHelper.cancel(task, "cancelled"));
228-
ensureNotCancelledCaptor.getValue().run();
229-
fail("expected computation to stop when all tasks are cancelled");
230-
return null;
231-
} finally {
232-
Thread.yield();
233-
assertTrue(isExecuting.compareAndSet(true, false));
234-
}
235-
});
236-
237-
ESTestCase.startInParallel(numThreads, threadNumber -> {
238-
tasks[threadNumber] = getTask();
239-
final SubscribableListener<TransportGetAllocationStatsAction.Response> listener = SubscribableListener.newForked(
240-
l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l)
241-
);
242-
safeAwaitFailure(listener);
243-
});
244-
}
245-
246-
public void testRunSomeTasksCancelledForSingleThread() throws InterruptedException {
247-
runSomeTasksCancelledForNumThreads(1, false);
248-
}
249-
250-
public void testRunSomeTasksCancelledForMultipleThreads() throws InterruptedException {
251-
runSomeTasksCancelledForNumThreads(between(2, 10), false);
252-
}
253-
254-
public void testRunSomeTasksCancelledForSingleThreadCacheDisabled() throws InterruptedException {
255-
runSomeTasksCancelledForNumThreads(1, true);
256-
}
257-
258-
public void testRunSomeTasksCancelledForMultipleThreadsCacheDisabled() throws InterruptedException {
259-
runSomeTasksCancelledForNumThreads(between(2, 10), true);
260-
}
261-
262-
private void runSomeTasksCancelledForNumThreads(final int numThreads, final boolean cacheDisabled) throws InterruptedException {
263-
if (cacheDisabled) {
264-
disableAllocationStatsCache();
265-
}
266-
final var isExecuting = new AtomicBoolean();
267-
final var ensureNotCancelledCaptor = ArgumentCaptor.forClass(Runnable.class);
268-
final var tasks = new CancellableTask[numThreads];
269221
final var cancellations = new boolean[numThreads];
270222
final var stats = Map.of(randomIdentifier(), NodeAllocationStatsTests.randomNodeAllocationStats());
271223

@@ -287,17 +239,17 @@ private void runSomeTasksCancelledForNumThreads(final int numThreads, final bool
287239

288240
ESTestCase.startInParallel(numThreads, threadNumber -> {
289241
tasks[threadNumber] = getTask();
290-
cancellations[threadNumber] = threadNumber > 0;
291-
final SubscribableListener<TransportGetAllocationStatsAction.Response> listener = SubscribableListener.newForked(
292-
l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l)
293-
);
294-
listener.addListener(ActionListener.wrap(response -> { assertSame(stats, response.getNodeAllocationStats()); }, e -> {
242+
cancellations[threadNumber] = cancelAllTasks || randomBoolean();
243+
final ActionListener<TransportGetAllocationStatsAction.Response> listener = ActionListener.wrap(response -> {
244+
assertSame(stats, response.getNodeAllocationStats());
245+
}, e -> {
295246
if (e instanceof TaskCancelledException) {
296247
assertTrue("got an unexpected cancellation exception for thread " + threadNumber, cancellations[threadNumber]);
297248
} else {
298249
fail(e);
299250
}
300-
}));
251+
});
252+
ActionListener.run(listener, l -> action.masterOperation(tasks[threadNumber], getRequest(), ClusterState.EMPTY_STATE, l));
301253
});
302254
}
303255

@@ -345,7 +297,9 @@ public void testGetStatsWithCachingEnabled() throws Exception {
345297
verifyAllocationStatsServiceNumCallsEqualTo(numExpectedAllocationStatsServiceCalls);
346298

347299
// Re-enable the cache, only one thread should call the stats service.
348-
setAllocationStatsCacheTTL(TimeValue.timeValueMinutes(5));
300+
final var newTTL = TimeValue.timeValueMinutes(5);
301+
setAllocationStatsCacheTTL(newTTL);
302+
threadPool.setCurrentTimeInMillis(threadPool.relativeTimeInMillis() + newTTL.getMillis() + 1);
349303
resetExpectedAllocationStats.run();
350304
ESTestCase.startInParallel(between(1, 5), threadNumber -> safeAwait(threadTask));
351305
verifyAllocationStatsServiceNumCallsEqualTo(++numExpectedAllocationStatsServiceCalls);

0 commit comments

Comments
 (0)