From a7804f5fc8f1524a19f7e0fefa8539b0961e96bc Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 2 Feb 2023 07:43:51 +0000 Subject: [PATCH 1/2] Extract ThrottledTaskRunner Generalizes `PrioritizedThrottledTaskRunner` slightly: - The throttling behaviour is also useful for tasks which do not complete synchronously. The new `ThrottledTaskRunner` passes a `Releasable` to each task, which until released will prevent spawning further tasks. - The only part that needs the tasks to be `Comparable<>` is the queue. Letting the caller specify the queue means that we can also use the throttling without the prioritisation. --- .../AbstractThrottledTaskRunner.java | 158 +++++++++++++ .../PrioritizedThrottledTaskRunner.java | 156 +++++-------- .../util/concurrent/ThrottledTaskRunner.java | 21 ++ .../AbstractThrottledTaskRunnerTests.java | 209 ++++++++++++++++++ .../PrioritizedThrottledTaskRunnerTests.java | 37 ++-- .../common/util/concurrent/TestBarrier.java | 52 +++++ .../util/concurrent/TestBarrierTests.java | 59 +++++ 7 files changed, 567 insertions(+), 125 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunner.java create mode 100644 server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledTaskRunner.java create mode 100644 server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java create mode 100644 test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java create mode 100644 test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunner.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunner.java new file mode 100644 index 0000000000000..ea37dad5ba218 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunner.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Strings; + +import java.util.Queue; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * {@link AbstractThrottledTaskRunner} runs the enqueued tasks using the given executor, limiting the number of tasks that are submitted to + * the executor at once. + */ +public class AbstractThrottledTaskRunner> { + private static final Logger logger = LogManager.getLogger(AbstractThrottledTaskRunner.class); + + private final String taskRunnerName; + // The max number of tasks that this runner will schedule to concurrently run on the executor. + private final int maxRunningTasks; + // As we fork off dequeued tasks to the given executor, technically the following counter represents + // the number of the concurrent pollAndSpawn calls currently checking the queue for a task to run. This + // doesn't necessarily correspond to currently running tasks, since a pollAndSpawn could return without + // actually running a task when the queue is empty. + private final AtomicInteger runningTasks = new AtomicInteger(); + private final Queue tasks; + private final Executor executor; + + public AbstractThrottledTaskRunner(final String name, final int maxRunningTasks, final Executor executor, final Queue taskQueue) { + assert maxRunningTasks > 0; + this.taskRunnerName = name; + this.maxRunningTasks = maxRunningTasks; + this.executor = executor; + this.tasks = taskQueue; + } + + /** + * Submits a task for execution. If there are fewer than {@code maxRunningTasks} tasks currently running then this task is immediately + * submitted to the executor. Otherwise this task is enqueued and will be submitted to the executor in turn on completion of some other + * task. + * + * Tasks are executed via their {@link ActionListener#onResponse} method, receiving a {@link Releasable} which must be closed on + * completion of the task. Task which are rejected from their executor are notified via their {@link ActionListener#onFailure} method. + * Neither of these methods may themselves throw exceptions. + */ + public void enqueueTask(final T task) { + logger.trace("[{}] enqueuing task {}", taskRunnerName, task); + tasks.add(task); + // Try to run a task since now there is at least one in the queue. If the maxRunningTasks is + // reached, the task is just enqueued. + pollAndSpawn(); + } + + /** + * Allows certain tasks to force their execution, bypassing the queue-length limit on the executor. See also {@link + * AbstractRunnable#isForceExecution()}. + */ + protected boolean isForceExecution(@SuppressWarnings("unused") /* TODO test this */ T task) { + return false; + } + + private void pollAndSpawn() { + // A pollAndSpawn attempts to run a new task. There could be many concurrent pollAndSpawn calls competing + // to get a "free slot", since we attempt to run a new task on every enqueueTask call and every time an + // existing task is finished. + while (incrementRunningTasks()) { + T task = tasks.poll(); + if (task == null) { + logger.trace("[{}] task queue is empty", taskRunnerName); + // We have taken up a "free slot", but there are no tasks in the queue! This could happen each time a worker + // sees an empty queue after running a task. Decrement to give competing pollAndSpawn calls a chance! + int decremented = runningTasks.decrementAndGet(); + assert decremented >= 0; + // We might have blocked all competing pollAndSpawn calls. This could happen for example when + // maxRunningTasks=1 and a task got enqueued just after checking the queue but before decrementing. + // To be sure, return only if the queue is still empty. If the queue is not empty, this might be the + // only pollAndSpawn call in progress, and returning without peeking would risk ending up with a + // non-empty queue and no workers! + if (tasks.peek() == null) break; + } else { + final boolean isForceExecution = isForceExecution(task); + executor.execute(new AbstractRunnable() { + private boolean rejected; // need not be volatile - if we're rejected then that happens-before calling onAfter + + private final Releasable releasable = Releasables.releaseOnce(() -> { + // To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running + // a task is finished. + int decremented = runningTasks.decrementAndGet(); + assert decremented >= 0; + + if (rejected == false) { + pollAndSpawn(); + } + }); + + @Override + public boolean isForceExecution() { + return isForceExecution; + } + + @Override + public void onRejection(Exception e) { + logger.trace("[{}] task {} rejected", taskRunnerName, task); + rejected = true; + try { + task.onFailure(e); + } finally { + releasable.close(); + } + } + + @Override + public void onFailure(Exception e) { + // should not happen + logger.error(() -> Strings.format("[%s] task %s failed", taskRunnerName, task), e); + assert false : e; + task.onFailure(e); + } + + @Override + protected void doRun() { + logger.trace("[{}] running task {}", taskRunnerName, task); + task.onResponse(releasable); + } + + @Override + public String toString() { + return task.toString(); + } + }); + } + } + } + + // Each worker thread that runs a task, first needs to get a "free slot" in order to respect maxRunningTasks. + private boolean incrementRunningTasks() { + int preUpdateValue = runningTasks.getAndAccumulate(maxRunningTasks, (v, maxRunning) -> v < maxRunning ? v + 1 : v); + assert preUpdateValue <= maxRunningTasks; + return preUpdateValue < maxRunningTasks; + } + + // exposed for testing + int runningTasks() { + return runningTasks.get(); + } + +} diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java index cc2c8264ab289..96d6b0ed94713 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java @@ -8,14 +8,11 @@ package org.elasticsearch.common.util.concurrent; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.core.Strings; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Releasable; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.PriorityBlockingQueue; -import java.util.concurrent.atomic.AtomicInteger; /** * {@link PrioritizedThrottledTaskRunner} performs the enqueued tasks in the order dictated by the @@ -23,119 +20,70 @@ * that is dequeued to be run, is forked off to the given executor. */ public class PrioritizedThrottledTaskRunner> { - private static final Logger logger = LogManager.getLogger(PrioritizedThrottledTaskRunner.class); - - private final String taskRunnerName; - // The max number of tasks that this runner will schedule to concurrently run on the executor. - private final int maxRunningTasks; - // As we fork off dequeued tasks to the given executor, technically the following counter represents - // the number of the concurrent pollAndSpawn calls currently checking the queue for a task to run. This - // doesn't necessarily correspond to currently running tasks, since a pollAndSpawn could return without - // actually running a task when the queue is empty. - private final AtomicInteger runningTasks = new AtomicInteger(); - private final BlockingQueue tasks = new PriorityBlockingQueue<>(); - private final Executor executor; - public PrioritizedThrottledTaskRunner(final String name, final int maxRunningTasks, final Executor executor) { - assert maxRunningTasks > 0; - this.taskRunnerName = name; - this.maxRunningTasks = maxRunningTasks; - this.executor = executor; - } + private final AbstractThrottledTaskRunner> runner; + private final PriorityBlockingQueue> queue; - public void enqueueTask(final T task) { - logger.trace("[{}] enqueuing task {}", taskRunnerName, task); - tasks.add(task); - // Try to run a task since now there is at least one in the queue. If the maxRunningTasks is - // reached, the task is just enqueued. - pollAndSpawn(); - } + private static class TaskWrapper> + implements + ActionListener, + Comparable> { + + private final T task; + + TaskWrapper(T task) { + this.task = task; + } + + @Override + public int compareTo(TaskWrapper o) { + return task.compareTo(o.task); + } - private void pollAndSpawn() { - // A pollAndSpawn attempts to run a new task. There could be many concurrent pollAndSpawn calls competing - // to get a "free slot", since we attempt to run a new task on every enqueueTask call and every time an - // existing task is finished. - while (incrementRunningTasks()) { - T task = tasks.poll(); - if (task == null) { - logger.trace("[{}] task queue is empty", taskRunnerName); - // We have taken up a "free slot", but there are no tasks in the queue! This could happen each time a worker - // sees an empty queue after running a task. Decrement to give competing pollAndSpawn calls a chance! - int decremented = runningTasks.decrementAndGet(); - assert decremented >= 0; - // We might have blocked all competing pollAndSpawn calls. This could happen for example when - // maxRunningTasks=1 and a task got enqueued just after checking the queue but before decrementing. - // To be sure, return only if the queue is still empty. If the queue is not empty, this might be the - // only pollAndSpawn call in progress, and returning without peeking would risk ending up with a - // non-empty queue and no workers! - if (tasks.peek() == null) break; - } else { - executor.execute(new AbstractRunnable() { - private boolean rejected; // need not be volatile - if we're rejected then that happens-before calling onAfter - - @Override - public boolean isForceExecution() { - return task.isForceExecution(); - } - - @Override - public void onRejection(Exception e) { - logger.trace("[{}] task {} rejected", taskRunnerName, task); - rejected = true; - task.onRejection(e); - } - - @Override - public void onFailure(Exception e) { - logger.trace(() -> Strings.format("[%s] task %s failed", taskRunnerName, task), e); - task.onFailure(e); - } - - @Override - protected void doRun() throws Exception { - logger.trace("[{}] running task {}", taskRunnerName, task); - task.doRun(); - } - - @Override - public void onAfter() { - try { - task.onAfter(); - } finally { - // To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running - // a task is finished. - int decremented = runningTasks.decrementAndGet(); - assert decremented >= 0; - - if (rejected == false) { - pollAndSpawn(); - } - } - } - - @Override - public String toString() { - return task.toString(); - } - }); + @Override + public String toString() { + return task.toString(); + } + + @Override + public void onResponse(Releasable releasable) { + try (releasable) { + task.run(); + } + } + + @Override + public void onFailure(Exception e) { + assert e instanceof EsRejectedExecutionException : e; + try { + task.onRejection(e); + } finally { + task.onAfter(); } } } - // Each worker thread that runs a task, first needs to get a "free slot" in order to respect maxRunningTasks. - private boolean incrementRunningTasks() { - int preUpdateValue = runningTasks.getAndAccumulate(maxRunningTasks, (v, maxRunning) -> v < maxRunning ? v + 1 : v); - assert preUpdateValue <= maxRunningTasks; - return preUpdateValue < maxRunningTasks; + public PrioritizedThrottledTaskRunner(final String name, final int maxRunningTasks, final Executor executor) { + this.queue = new PriorityBlockingQueue<>(); + this.runner = new AbstractThrottledTaskRunner<>(name, maxRunningTasks, executor, queue); + } + + /** + * Submits a task for execution. If there are fewer than {@code maxRunningTasks} tasks currently running then this task is immediately + * submitted to the executor. Otherwise this task is enqueued and will be submitted to the executor in turn on completion of some other + * task. + */ + public void enqueueTask(final T task) { + runner.enqueueTask(new TaskWrapper<>(task)); } // Only use for testing public int runningTasks() { - return runningTasks.get(); + return runner.runningTasks(); } // Only use for testing public int queueSize() { - return tasks.size(); + return queue.size(); } } diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledTaskRunner.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledTaskRunner.java new file mode 100644 index 0000000000000..674e58ee766db --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledTaskRunner.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Releasable; + +import java.util.concurrent.Executor; + +public class ThrottledTaskRunner extends AbstractThrottledTaskRunner> { + // a simple AbstractThrottledTaskRunner which fixes the task type and uses a regular FIFO blocking queue. + public ThrottledTaskRunner(String name, int maxRunningTasks, Executor executor) { + super(name, maxRunningTasks, executor, ConcurrentCollections.newBlockingQueue()); + } +} diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java new file mode 100644 index 0000000000000..0163ad3a82a12 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.test.ESTestCase; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class AbstractThrottledTaskRunnerTests extends ESTestCase { + + private static final ThreadFactory threadFactory = EsExecutors.daemonThreadFactory("test"); + private static final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + private ExecutorService executor; + private int maxThreads; + + @Override + public void setUp() throws Exception { + super.setUp(); + maxThreads = between(1, 10); + executor = EsExecutors.newScaling("test", 1, maxThreads, 0, TimeUnit.MILLISECONDS, false, threadFactory, threadContext); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + terminate(executor); + } + + public void testMultiThreadedEnqueue() throws Exception { + final int maxTasks = randomIntBetween(1, 2 * maxThreads); + final var permits = new Semaphore(maxTasks); + final int totalTasks = randomIntBetween(2 * maxTasks, 10 * maxTasks); + final var latch = new CountDownLatch(totalTasks); + + class TestTask implements ActionListener { + + private final ExecutorService taskExecutor = randomFrom(executor, EsExecutors.DIRECT_EXECUTOR_SERVICE); + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + + @Override + public void onResponse(Releasable releasable) { + assertTrue(permits.tryAcquire()); + try { + Thread.sleep(between(0, 10)); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + taskExecutor.execute(() -> { + permits.release(); + releasable.close(); + latch.countDown(); + }); + } + } + + final BlockingQueue queue = ConcurrentCollections.newBlockingQueue(); + final AbstractThrottledTaskRunner taskRunner = new AbstractThrottledTaskRunner<>("test", maxTasks, executor, queue); + + final var threadBlocker = new TestBarrier(totalTasks); + for (int i = 0; i < totalTasks; i++) { + new Thread(() -> { + threadBlocker.await(); + taskRunner.enqueueTask(new TestTask()); + assertThat(taskRunner.runningTasks(), lessThanOrEqualTo(maxTasks)); + }).start(); + } + // Eventually all tasks are executed + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertTrue(queue.isEmpty()); + assertTrue(permits.tryAcquire(maxTasks)); + assertNoRunningTasks(taskRunner); + } + + public void testEnqueueSpawnsNewTasksUpToMax() throws Exception { + int maxTasks = randomIntBetween(1, maxThreads); + final int enqueued = maxTasks - 1; // So that it is possible to run at least one more task + final int newTasks = randomIntBetween(1, 10); + + CountDownLatch taskBlocker = new CountDownLatch(1); + CountDownLatch executedCountDown = new CountDownLatch(enqueued + newTasks); + + class TestTask implements ActionListener { + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + + @Override + public void onResponse(Releasable releasable) { + try { + taskBlocker.await(); + } catch (InterruptedException e) { + throw new AssertionError(e); + } finally { + executedCountDown.countDown(); + releasable.close(); + } + } + } + + final BlockingQueue queue = ConcurrentCollections.newBlockingQueue(); + final AbstractThrottledTaskRunner taskRunner = new AbstractThrottledTaskRunner<>("test", maxTasks, executor, queue); + for (int i = 0; i < enqueued; i++) { + taskRunner.enqueueTask(new TestTask()); + assertThat(taskRunner.runningTasks(), equalTo(i + 1)); + assertTrue(queue.isEmpty()); + } + // Enqueueing one or more new tasks would create only one new running task + for (int i = 0; i < newTasks; i++) { + taskRunner.enqueueTask(new TestTask()); + assertThat(taskRunner.runningTasks(), equalTo(maxTasks)); + assertThat(queue.size(), equalTo(i)); + } + taskBlocker.countDown(); + /// Eventually all tasks are executed + assertTrue(executedCountDown.await(10, TimeUnit.SECONDS)); + assertTrue(queue.isEmpty()); + assertNoRunningTasks(taskRunner); + } + + public void testFailsTasksOnRejectionOrShutdown() throws Exception { + final var executor = randomBoolean() + ? EsExecutors.newScaling("test", 1, maxThreads, 0, TimeUnit.MILLISECONDS, true, threadFactory, threadContext) + : EsExecutors.newFixed("test", maxThreads, between(1, 5), threadFactory, threadContext, false); + + final var totalPermits = between(1, maxThreads * 2); + final var permits = new Semaphore(totalPermits); + final var taskCompleted = new CountDownLatch(between(1, maxThreads * 2)); + final var rejectionCountDown = new CountDownLatch(between(1, maxThreads * 2)); + + class TestTask implements ActionListener { + + @Override + public void onFailure(Exception e) { + rejectionCountDown.countDown(); + permits.release(); + } + + @Override + public void onResponse(Releasable releasable) { + permits.release(); + taskCompleted.countDown(); + releasable.close(); + } + } + + final BlockingQueue queue = ConcurrentCollections.newBlockingQueue(); + final AbstractThrottledTaskRunner taskRunner = new AbstractThrottledTaskRunner<>( + "test", + between(1, maxThreads * 2), + executor, + queue + ); + + final var spawnThread = new Thread(() -> { + try { + while (true) { + assertTrue(permits.tryAcquire(10, TimeUnit.SECONDS)); + taskRunner.enqueueTask(new TestTask()); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + spawnThread.start(); + assertTrue(taskCompleted.await(10, TimeUnit.SECONDS)); + executor.shutdown(); + assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS)); + assertTrue(rejectionCountDown.await(10, TimeUnit.SECONDS)); + spawnThread.interrupt(); + spawnThread.join(); + assertThat(taskRunner.runningTasks(), equalTo(0)); + assertTrue(queue.isEmpty()); + assertTrue(permits.tryAcquire(totalPermits)); + } + + private void assertNoRunningTasks(AbstractThrottledTaskRunner taskRunner) { + final var barrier = new TestBarrier(maxThreads + 1); + for (int i = 0; i < maxThreads; i++) { + executor.execute(barrier::await); + } + barrier.await(); + assertThat(taskRunner.runningTasks(), equalTo(0)); + } + +} diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java index f1955ebcddacb..342ceccd7714d 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java @@ -18,7 +18,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; @@ -83,12 +82,12 @@ public void testMultiThreadedEnqueue() throws Exception { final int maxTasks = randomIntBetween(1, maxThreads); PrioritizedThrottledTaskRunner taskRunner = new PrioritizedThrottledTaskRunner<>("test", maxTasks, executor); final int enqueued = randomIntBetween(2 * maxTasks, 10 * maxTasks); - final var threadBlocker = new CyclicBarrier(enqueued); + final var threadBlocker = new TestBarrier(enqueued); final var executedCountDown = new CountDownLatch(enqueued); for (int i = 0; i < enqueued; i++) { final int taskId = i; new Thread(() -> { - awaitBarrier(threadBlocker); + threadBlocker.await(10, TimeUnit.SECONDS); taskRunner.enqueueTask(new TestTask(() -> { try { Thread.sleep(randomLongBetween(0, 10)); @@ -113,10 +112,10 @@ public void testTasksRunInOrder() throws Exception { final var taskRunner = new PrioritizedThrottledTaskRunner("test", 1, executor); - final var blockBarrier = new CyclicBarrier(2); + final var blockBarrier = new TestBarrier(2); taskRunner.enqueueTask(new TestTask(() -> { - awaitBarrier(blockBarrier); // notify main thread that the runner is blocked - awaitBarrier(blockBarrier); // wait for main thread to finish enqueuing tasks + blockBarrier.await(10, TimeUnit.SECONDS); // notify main thread that the runner is blocked + blockBarrier.await(10, TimeUnit.SECONDS); // wait for main thread to finish enqueuing tasks }, getRandomPriority(), "blocking task")); blockBarrier.await(10, TimeUnit.SECONDS); // wait for blocking task to start executing @@ -124,23 +123,26 @@ public void testTasksRunInOrder() throws Exception { final int enqueued = randomIntBetween(2 * n, 10 * n); List taskPriorities = new ArrayList<>(enqueued); List executedPriorities = new ArrayList<>(enqueued); - final var enqueuedBarrier = new CyclicBarrier(enqueued + 1); + final var enqueuedBarrier = new TestBarrier(enqueued + 1); final var executedCountDown = new CountDownLatch(enqueued); for (int i = 0; i < enqueued; i++) { final int taskId = i; final int priority = getRandomPriority(); taskPriorities.add(priority); new Thread(() -> { - awaitBarrier(enqueuedBarrier); // wait until all threads are ready so the enqueueTask() calls are as concurrent as possible + // wait until all threads are ready so the enqueueTask() calls are as concurrent as possible + enqueuedBarrier.await(10, TimeUnit.SECONDS); taskRunner.enqueueTask(new TestTask(() -> { executedPriorities.add(priority); executedCountDown.countDown(); }, priority, "concurrent enqueued tasks - " + taskId)); - awaitBarrier(enqueuedBarrier); // notify main thread that the task is enqueued + enqueuedBarrier.await(10, TimeUnit.SECONDS); // notify main thread that the task is enqueued }).start(); } - awaitBarrier(enqueuedBarrier); // release all the threads at once - awaitBarrier(enqueuedBarrier); // wait for all threads to confirm the task is enqueued + // release all the threads at once + enqueuedBarrier.await(10, TimeUnit.SECONDS); + // wait for all threads to confirm the task is enqueued + enqueuedBarrier.await(10, TimeUnit.SECONDS); assertThat(taskRunner.queueSize(), equalTo(enqueued)); blockBarrier.await(10, TimeUnit.SECONDS); // notify blocking task that it can continue @@ -245,23 +247,16 @@ private int getRandomPriority() { private void assertNoRunningTasks(PrioritizedThrottledTaskRunner taskRunner) { logger.info("--> ensure that there are no running tasks in the executor. Max number of threads [{}]", maxThreads); - final var barrier = new CyclicBarrier(maxThreads + 1); + final var barrier = new TestBarrier(maxThreads + 1); for (int i = 0; i < maxThreads; i++) { executor.execute(() -> { logger.info("--> await until barrier is released"); - awaitBarrier(barrier); + barrier.await(10, TimeUnit.SECONDS); logger.info("--> the barrier is released"); }); } - awaitBarrier(barrier); + barrier.await(10, TimeUnit.SECONDS); assertThat(taskRunner.runningTasks(), equalTo(0)); } - private static void awaitBarrier(CyclicBarrier barrier) { - try { - barrier.await(10, TimeUnit.SECONDS); - } catch (Exception e) { - throw new AssertionError("unexpected", e); - } - } } diff --git a/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java b/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java new file mode 100644 index 0000000000000..b77aabe068633 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; + +/** + * A {@link CyclicBarrier} whose {@link #await} methods throw an {@link AssertionError} instead of any checked exceptions, for use in tests. + */ +public class TestBarrier extends CyclicBarrier { + public TestBarrier(int parties) { + super(parties); + } + + @Deprecated(since = "it's probably a mistake to use this in a test because tests should not be waiting forever") + public int awaitForever() { + try { + return super.await(); + } catch (Exception e) { + throw new AssertionError("unexpected", e); + } + } + + /** + * {@link #await} with a 30s timeout. + */ + public int awaitLong() { + return await(30, TimeUnit.SECONDS); + } + + @Override + public int await() { + // in general tests should not wait forever, so this method imposes a default timeout of 10s + return await(10, TimeUnit.SECONDS); + } + + @Override + public int await(long timeout, TimeUnit unit) { + try { + return super.await(timeout, unit); + } catch (Exception e) { + throw new AssertionError("unexpected", e); + } + } +} diff --git a/test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java b/test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java new file mode 100644 index 0000000000000..8d2d4a233dea4 --- /dev/null +++ b/test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +public class TestBarrierTests extends ESTestCase { + + public void testUsage() throws InterruptedException { + final var barrier = new TestBarrier(2); + final var completed = new AtomicBoolean(); + final var thread = new Thread(() -> { + barrier.await(); + assertTrue(completed.compareAndSet(false, true)); + }); + thread.start(); + assertFalse(completed.get()); + assertThat(randomAwait(barrier), Matchers.oneOf(0, 1)); + thread.join(); + assertTrue(completed.get()); + } + + public void testExceptions() throws InterruptedException { + final var barrier = new TestBarrier(2); + final var completed = new AtomicBoolean(); + final var thread = new Thread(() -> { + expectThrows(AssertionError.class, () -> randomAwait(barrier)); + assertTrue(completed.compareAndSet(false, true)); + }); + thread.start(); + do { + barrier.reset(); + thread.join(10); + } while (thread.isAlive()); + assertTrue(completed.get()); + } + + @SuppressWarnings("deprecation") + private int randomAwait(TestBarrier barrier) { + return switch (between(1, 4)) { + case 1 -> barrier.await(); + case 2 -> barrier.awaitLong(); + case 3 -> barrier.await(10, TimeUnit.SECONDS); + case 4 -> barrier.awaitForever(); + default -> throw new AssertionError(); + }; + } + +} From 0aa0ea04a8881336f36c2582875ad30a9f504ca8 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 6 Feb 2023 08:24:31 +0000 Subject: [PATCH 2/2] CyclicBarrierUtils --- .../AbstractThrottledTaskRunnerTests.java | 11 ++-- .../PrioritizedThrottledTaskRunnerTests.java | 22 +++---- .../util/concurrent/CyclicBarrierUtils.java | 30 ++++++++++ .../common/util/concurrent/TestBarrier.java | 52 ---------------- .../util/concurrent/TestBarrierTests.java | 59 ------------------- 5 files changed, 47 insertions(+), 127 deletions(-) create mode 100644 test/framework/src/main/java/org/elasticsearch/common/util/concurrent/CyclicBarrierUtils.java delete mode 100644 test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java delete mode 100644 test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java index 0163ad3a82a12..a47000a395ac8 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/AbstractThrottledTaskRunnerTests.java @@ -15,6 +15,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; @@ -78,10 +79,10 @@ public void onResponse(Releasable releasable) { final BlockingQueue queue = ConcurrentCollections.newBlockingQueue(); final AbstractThrottledTaskRunner taskRunner = new AbstractThrottledTaskRunner<>("test", maxTasks, executor, queue); - final var threadBlocker = new TestBarrier(totalTasks); + final var threadBlocker = new CyclicBarrier(totalTasks); for (int i = 0; i < totalTasks; i++) { new Thread(() -> { - threadBlocker.await(); + CyclicBarrierUtils.await(threadBlocker); taskRunner.enqueueTask(new TestTask()); assertThat(taskRunner.runningTasks(), lessThanOrEqualTo(maxTasks)); }).start(); @@ -198,11 +199,11 @@ public void onResponse(Releasable releasable) { } private void assertNoRunningTasks(AbstractThrottledTaskRunner taskRunner) { - final var barrier = new TestBarrier(maxThreads + 1); + final var barrier = new CyclicBarrier(maxThreads + 1); for (int i = 0; i < maxThreads; i++) { - executor.execute(barrier::await); + executor.execute(() -> CyclicBarrierUtils.await(barrier)); } - barrier.await(); + CyclicBarrierUtils.await(barrier); assertThat(taskRunner.runningTasks(), equalTo(0)); } diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java index 01a318f3e5af6..6ded72fd1f625 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java @@ -75,12 +75,12 @@ public void testMultiThreadedEnqueue() throws Exception { final int maxTasks = randomIntBetween(1, maxThreads); PrioritizedThrottledTaskRunner taskRunner = new PrioritizedThrottledTaskRunner<>("test", maxTasks, executor); final int enqueued = randomIntBetween(2 * maxTasks, 10 * maxTasks); - final var threadBlocker = new TestBarrier(enqueued); + final var threadBlocker = new CyclicBarrier(enqueued); final var executedCountDown = new CountDownLatch(enqueued); for (int i = 0; i < enqueued; i++) { final int taskId = i; new Thread(() -> { - threadBlocker.await(10, TimeUnit.SECONDS); + CyclicBarrierUtils.await(threadBlocker); taskRunner.enqueueTask(new TestTask(() -> { try { Thread.sleep(randomLongBetween(0, 10)); @@ -105,10 +105,10 @@ public void testTasksRunInOrder() throws Exception { final var taskRunner = new PrioritizedThrottledTaskRunner("test", 1, executor); - final var blockBarrier = new TestBarrier(2); + final var blockBarrier = new CyclicBarrier(2); taskRunner.enqueueTask(new TestTask(() -> { - blockBarrier.await(10, TimeUnit.SECONDS); // notify main thread that the runner is blocked - blockBarrier.await(10, TimeUnit.SECONDS); // wait for main thread to finish enqueuing tasks + CyclicBarrierUtils.await(blockBarrier); // notify main thread that the runner is blocked + CyclicBarrierUtils.await(blockBarrier); // wait for main thread to finish enqueuing tasks }, getRandomPriority())); blockBarrier.await(10, TimeUnit.SECONDS); // wait for blocking task to start executing @@ -116,7 +116,7 @@ public void testTasksRunInOrder() throws Exception { final int enqueued = randomIntBetween(2 * n, 10 * n); List taskPriorities = new ArrayList<>(enqueued); List executedPriorities = new ArrayList<>(enqueued); - final var enqueuedBarrier = new TestBarrier(enqueued + 1); + final var enqueuedBarrier = new CyclicBarrier(enqueued + 1); final var executedCountDown = new CountDownLatch(enqueued); for (int i = 0; i < enqueued; i++) { final int taskId = i; @@ -124,12 +124,12 @@ public void testTasksRunInOrder() throws Exception { taskPriorities.add(priority); new Thread(() -> { // wait until all threads are ready so the enqueueTask() calls are as concurrent as possible - enqueuedBarrier.await(10, TimeUnit.SECONDS); + CyclicBarrierUtils.await(enqueuedBarrier); taskRunner.enqueueTask(new TestTask(() -> { executedPriorities.add(priority); executedCountDown.countDown(); }, priority)); - enqueuedBarrier.await(10, TimeUnit.SECONDS); // notify main thread that the task is enqueued + CyclicBarrierUtils.await(enqueuedBarrier); // notify main thread that the task is enqueued }).start(); } // release all the threads at once @@ -237,11 +237,11 @@ private int getRandomPriority() { } private void assertNoRunningTasks(PrioritizedThrottledTaskRunner taskRunner) { - final var barrier = new TestBarrier(maxThreads + 1); + final var barrier = new CyclicBarrier(maxThreads + 1); for (int i = 0; i < maxThreads; i++) { - executor.execute(() -> barrier.await(10, TimeUnit.SECONDS)); + executor.execute(() -> CyclicBarrierUtils.await(barrier)); } - barrier.await(10, TimeUnit.SECONDS); + CyclicBarrierUtils.await(barrier); assertThat(taskRunner.runningTasks(), equalTo(0)); } diff --git a/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/CyclicBarrierUtils.java b/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/CyclicBarrierUtils.java new file mode 100644 index 0000000000000..7fd9e11227250 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/CyclicBarrierUtils.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.util.concurrent; + +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; + +public class CyclicBarrierUtils { + private CyclicBarrierUtils() { + // no instances + } + + /** + * Await the given {@link CyclicBarrier}, failing the test after 10s with an {@link AssertionError}. Tests should not wait forever, so + * a timed wait is always appropriate. + */ + public static void await(CyclicBarrier cyclicBarrier) { + try { + cyclicBarrier.await(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new AssertionError("unexpected", e); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java b/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java deleted file mode 100644 index b77aabe068633..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/common/util/concurrent/TestBarrier.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.common.util.concurrent; - -import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.TimeUnit; - -/** - * A {@link CyclicBarrier} whose {@link #await} methods throw an {@link AssertionError} instead of any checked exceptions, for use in tests. - */ -public class TestBarrier extends CyclicBarrier { - public TestBarrier(int parties) { - super(parties); - } - - @Deprecated(since = "it's probably a mistake to use this in a test because tests should not be waiting forever") - public int awaitForever() { - try { - return super.await(); - } catch (Exception e) { - throw new AssertionError("unexpected", e); - } - } - - /** - * {@link #await} with a 30s timeout. - */ - public int awaitLong() { - return await(30, TimeUnit.SECONDS); - } - - @Override - public int await() { - // in general tests should not wait forever, so this method imposes a default timeout of 10s - return await(10, TimeUnit.SECONDS); - } - - @Override - public int await(long timeout, TimeUnit unit) { - try { - return super.await(timeout, unit); - } catch (Exception e) { - throw new AssertionError("unexpected", e); - } - } -} diff --git a/test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java b/test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java deleted file mode 100644 index 8d2d4a233dea4..0000000000000 --- a/test/framework/src/test/java/org/elasticsearch/common/util/concurrent/TestBarrierTests.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.common.util.concurrent; - -import org.elasticsearch.test.ESTestCase; -import org.hamcrest.Matchers; - -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; - -public class TestBarrierTests extends ESTestCase { - - public void testUsage() throws InterruptedException { - final var barrier = new TestBarrier(2); - final var completed = new AtomicBoolean(); - final var thread = new Thread(() -> { - barrier.await(); - assertTrue(completed.compareAndSet(false, true)); - }); - thread.start(); - assertFalse(completed.get()); - assertThat(randomAwait(barrier), Matchers.oneOf(0, 1)); - thread.join(); - assertTrue(completed.get()); - } - - public void testExceptions() throws InterruptedException { - final var barrier = new TestBarrier(2); - final var completed = new AtomicBoolean(); - final var thread = new Thread(() -> { - expectThrows(AssertionError.class, () -> randomAwait(barrier)); - assertTrue(completed.compareAndSet(false, true)); - }); - thread.start(); - do { - barrier.reset(); - thread.join(10); - } while (thread.isAlive()); - assertTrue(completed.get()); - } - - @SuppressWarnings("deprecation") - private int randomAwait(TestBarrier barrier) { - return switch (between(1, 4)) { - case 1 -> barrier.await(); - case 2 -> barrier.awaitLong(); - case 3 -> barrier.await(10, TimeUnit.SECONDS); - case 4 -> barrier.awaitForever(); - default -> throw new AssertionError(); - }; - } - -}