diff --git a/bookkeeper-common/src/main/java/org/apache/bookkeeper/common/util/SingleThreadExecutor.java b/bookkeeper-common/src/main/java/org/apache/bookkeeper/common/util/SingleThreadExecutor.java index 3c514ebbdaf..d48d8e3613b 100644 --- a/bookkeeper-common/src/main/java/org/apache/bookkeeper/common/util/SingleThreadExecutor.java +++ b/bookkeeper-common/src/main/java/org/apache/bookkeeper/common/util/SingleThreadExecutor.java @@ -18,6 +18,7 @@ package org.apache.bookkeeper.common.util; +import com.google.common.annotations.VisibleForTesting; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.util.ArrayList; import java.util.List; @@ -29,6 +30,7 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.LongAdder; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -54,6 +56,11 @@ public class SingleThreadExecutor extends AbstractExecutorService implements Exe private final LongAdder tasksRejected = new LongAdder(); private final LongAdder tasksFailed = new LongAdder(); + private final int maxQueueCapacity; + private static final AtomicIntegerFieldUpdater pendingTaskCountUpdater = + AtomicIntegerFieldUpdater.newUpdater(SingleThreadExecutor.class, "pendingTaskCount"); + private volatile int pendingTaskCount = 0; + enum State { Running, Shutdown, @@ -80,6 +87,8 @@ public SingleThreadExecutor(ThreadFactory tf, int maxQueueCapacity, boolean reje } else { this.queue = new GrowableMpScArrayConsumerBlockingQueue<>(); } + this.maxQueueCapacity = maxQueueCapacity; + this.runner = tf.newThread(this); this.state = State.Running; this.rejectExecution = rejectExecution; @@ -144,6 +153,8 @@ private boolean safeRunTask(Runnable r) { tasksFailed.increment(); log.error("Error while running task: {}", t.getMessage(), t); } + } finally { + decrementPendingTaskCount(1); } return true; @@ -162,7 +173,8 @@ public List shutdownNow() { this.state = State.Shutdown; this.runner.interrupt(); List remainingTasks = new ArrayList<>(); - queue.drainTo(remainingTasks); + int n = queue.drainTo(remainingTasks); + decrementPendingTaskCount(n); return remainingTasks; } @@ -204,20 +216,45 @@ public long getFailedTasksCount() { @Override public void execute(Runnable r) { + executeRunnableOrList(r, null); + } + + @VisibleForTesting + void executeRunnableOrList(Runnable runnable, List runnableList) { if (state != State.Running) { throw new RejectedExecutionException("Executor is shutting down"); } + boolean hasSingle = runnable != null; + boolean hasList = runnableList != null && !runnableList.isEmpty(); + + if (hasSingle == hasList) { + // Both are provided or both are missing + throw new IllegalArgumentException("Provide either 'runnable' or a non-empty 'runnableList', not both."); + } + try { if (!rejectExecution) { - queue.put(r); - tasksCount.increment(); - } else { - if (queue.offer(r)) { + if (hasSingle) { + queue.put(runnable); tasksCount.increment(); } else { - tasksRejected.increment(); - throw new ExecutorRejectedException("Executor queue is full"); + for (Runnable task : runnableList) { + queue.put(task); + tasksCount.increment(); + } + } + } else { + int count = runnable != null ? 1 : runnableList.size(); + incrementPendingTaskCount(count); + boolean success = hasSingle + ? queue.offer(runnable) + : queue.addAll(runnableList); + if (success) { + tasksCount.add(count); + } else { + decrementPendingTaskCount(count); + reject(); } } } catch (InterruptedException e) { @@ -225,6 +262,43 @@ public void execute(Runnable r) { } } + private void incrementPendingTaskCount(int count) { + if (maxQueueCapacity <= 0) { + return; // Unlimited capacity + } + + if (count < 0) { + throw new IllegalArgumentException("Count must be non-negative"); + } + + int oldPendingTaskCount = pendingTaskCountUpdater.getAndAccumulate(this, count, + (curr, inc) -> (curr + inc > maxQueueCapacity) ? curr : curr + inc); + + if (oldPendingTaskCount + count > maxQueueCapacity) { + reject(); + } + } + + private void decrementPendingTaskCount(int count) { + if (maxQueueCapacity <= 0) { + return; // Unlimited capacity + } + + if (count < 0) { + throw new IllegalArgumentException("Count must be non-negative"); + } + + int currentPendingCount = pendingTaskCountUpdater.addAndGet(this, -count); + if (log.isDebugEnabled()) { + log.debug("Released {} task(s), current pending count: {}", count, currentPendingCount); + } + } + + private void reject() { + tasksRejected.increment(); + throw new ExecutorRejectedException("Executor queue is full"); + } + public void registerMetrics(StatsLogger statsLogger) { // Register gauges statsLogger.scopeLabel("thread", runner.getName()) @@ -289,6 +363,11 @@ public Number getSample() { }); } + @VisibleForTesting + int getPendingTaskCount() { + return pendingTaskCountUpdater.get(this); + } + private static class ExecutorRejectedException extends RejectedExecutionException { private ExecutorRejectedException(String msg) { diff --git a/bookkeeper-common/src/test/java/org/apache/bookkeeper/common/util/TestSingleThreadExecutor.java b/bookkeeper-common/src/test/java/org/apache/bookkeeper/common/util/TestSingleThreadExecutor.java index 671318de6e2..ed72704e97a 100644 --- a/bookkeeper-common/src/test/java/org/apache/bookkeeper/common/util/TestSingleThreadExecutor.java +++ b/bookkeeper-common/src/test/java/org/apache/bookkeeper/common/util/TestSingleThreadExecutor.java @@ -20,10 +20,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.collect.Lists; import io.netty.util.concurrent.DefaultThreadFactory; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; @@ -83,20 +87,23 @@ public void testRejectWhenQueueIsFull() throws Exception { CountDownLatch startedLatch = new CountDownLatch(1); for (int i = 0; i < 10; i++) { + int n = i; ste.execute(() -> { - startedLatch.countDown(); - - try { - barrier.await(); - } catch (InterruptedException | BrokenBarrierException e) { - // ignore + if (n == 0) { + startedLatch.countDown(); + } else { + try { + barrier.await(); + } catch (InterruptedException | BrokenBarrierException e) { + // ignore + } } }); - - // Wait until the first task is already running in the thread - startedLatch.await(); } + // Wait until the first task is already running in the thread + startedLatch.await(); + // Next task should go through, because the runner thread has already pulled out 1 item from the // queue: the first tasks which is currently stuck ste.execute(() -> { @@ -116,6 +123,52 @@ public void testRejectWhenQueueIsFull() throws Exception { assertEquals(0, ste.getFailedTasksCount()); } + @Test + public void testRejectWhenDrainToInProgressAndQueueIsEmpty() throws Exception { + @Cleanup("shutdownNow") + SingleThreadExecutor ste = new SingleThreadExecutor(THREAD_FACTORY, 10, true); + + CountDownLatch waitedLatch = new CountDownLatch(1); + List tasks = new ArrayList<>(); + + for (int i = 0; i < 10; i++) { + tasks.add(() -> { + try { + // Block task to simulate an active, long-running task. + waitedLatch.await(); + } catch (Exception e) { + // ignored + } + }); + } + ste.executeRunnableOrList(null, tasks); + + Awaitility.await().pollDelay(1, TimeUnit.SECONDS) + .untilAsserted(() -> assertEquals(10, ste.getPendingTaskCount())); + + // Now the queue is really full and should reject tasks. + assertThrows(RejectedExecutionException.class, () -> ste.execute(() -> { + })); + + assertEquals(10, ste.getPendingTaskCount()); + assertEquals(1, ste.getRejectedTasksCount()); + assertEquals(0, ste.getFailedTasksCount()); + + // Now we can unblock the waited tasks. + waitedLatch.countDown(); + + // Check the tasks are completed. + Awaitility.await().pollDelay(1, TimeUnit.SECONDS) + .untilAsserted(() -> assertEquals(0, ste.getPendingTaskCount())); + + // Invalid cases - should throw IllegalArgumentException. + assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(null, null)); + assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(null, Collections.emptyList())); + assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(() -> { + }, Lists.newArrayList(() -> { + }))); + } + @Test public void testBlockWhenQueueIsFull() throws Exception { @Cleanup("shutdown")