Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.bookkeeper.common.util;

import com.google.common.annotations.VisibleForTesting;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -29,6 +30,7 @@
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.LongAdder;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -54,6 +56,11 @@ public class SingleThreadExecutor extends AbstractExecutorService implements Exe
private final LongAdder tasksRejected = new LongAdder();
private final LongAdder tasksFailed = new LongAdder();

private final int maxQueueCapacity;
private static final AtomicIntegerFieldUpdater<SingleThreadExecutor> pendingTaskCountUpdater =
AtomicIntegerFieldUpdater.newUpdater(SingleThreadExecutor.class, "pendingTaskCount");
private volatile int pendingTaskCount = 0;

enum State {
Running,
Shutdown,
Expand All @@ -80,6 +87,8 @@ public SingleThreadExecutor(ThreadFactory tf, int maxQueueCapacity, boolean reje
} else {
this.queue = new GrowableMpScArrayConsumerBlockingQueue<>();
}
this.maxQueueCapacity = maxQueueCapacity;

this.runner = tf.newThread(this);
this.state = State.Running;
this.rejectExecution = rejectExecution;
Expand Down Expand Up @@ -144,6 +153,8 @@ private boolean safeRunTask(Runnable r) {
tasksFailed.increment();
log.error("Error while running task: {}", t.getMessage(), t);
}
} finally {
decrementPendingTaskCount(1);
}

return true;
Expand All @@ -162,7 +173,8 @@ public List<Runnable> shutdownNow() {
this.state = State.Shutdown;
this.runner.interrupt();
List<Runnable> remainingTasks = new ArrayList<>();
queue.drainTo(remainingTasks);
int n = queue.drainTo(remainingTasks);
decrementPendingTaskCount(n);
return remainingTasks;
}

Expand Down Expand Up @@ -204,27 +216,89 @@ public long getFailedTasksCount() {

@Override
public void execute(Runnable r) {
executeRunnableOrList(r, null);
}

@VisibleForTesting
void executeRunnableOrList(Runnable runnable, List<Runnable> runnableList) {
if (state != State.Running) {
throw new RejectedExecutionException("Executor is shutting down");
}

boolean hasSingle = runnable != null;
boolean hasList = runnableList != null && !runnableList.isEmpty();

if (hasSingle == hasList) {
// Both are provided or both are missing
throw new IllegalArgumentException("Provide either 'runnable' or a non-empty 'runnableList', not both.");
}

try {
if (!rejectExecution) {
queue.put(r);
tasksCount.increment();
} else {
if (queue.offer(r)) {
if (hasSingle) {
queue.put(runnable);
tasksCount.increment();
} else {
tasksRejected.increment();
throw new ExecutorRejectedException("Executor queue is full");
for (Runnable task : runnableList) {
queue.put(task);
tasksCount.increment();
}
}
} else {
int count = runnable != null ? 1 : runnableList.size();
incrementPendingTaskCount(count);
boolean success = hasSingle
? queue.offer(runnable)
: queue.addAll(runnableList);
if (success) {
tasksCount.add(count);
} else {
decrementPendingTaskCount(count);
reject();
}
}
} catch (InterruptedException e) {
throw new RejectedExecutionException("Executor thread was interrupted", e);
}
}

private void incrementPendingTaskCount(int count) {
if (maxQueueCapacity <= 0) {
return; // Unlimited capacity
}

if (count < 0) {
throw new IllegalArgumentException("Count must be non-negative");
}

int oldPendingTaskCount = pendingTaskCountUpdater.getAndAccumulate(this, count,
(curr, inc) -> (curr + inc > maxQueueCapacity) ? curr : curr + inc);

if (oldPendingTaskCount + count > maxQueueCapacity) {
reject();
}
}

private void decrementPendingTaskCount(int count) {
if (maxQueueCapacity <= 0) {
return; // Unlimited capacity
}

if (count < 0) {
throw new IllegalArgumentException("Count must be non-negative");
}

int currentPendingCount = pendingTaskCountUpdater.addAndGet(this, -count);
if (log.isDebugEnabled()) {
log.debug("Released {} task(s), current pending count: {}", count, currentPendingCount);
}
}

private void reject() {
tasksRejected.increment();
throw new ExecutorRejectedException("Executor queue is full");
}

public void registerMetrics(StatsLogger statsLogger) {
// Register gauges
statsLogger.scopeLabel("thread", runner.getName())
Expand Down Expand Up @@ -289,6 +363,11 @@ public Number getSample() {
});
}

@VisibleForTesting
int getPendingTaskCount() {
return pendingTaskCountUpdater.get(this);
}

private static class ExecutorRejectedException extends RejectedExecutionException {

private ExecutorRejectedException(String msg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import com.google.common.collect.Lists;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -83,20 +87,23 @@ public void testRejectWhenQueueIsFull() throws Exception {
CountDownLatch startedLatch = new CountDownLatch(1);

for (int i = 0; i < 10; i++) {
int n = i;
ste.execute(() -> {
startedLatch.countDown();

try {
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
// ignore
if (n == 0) {
startedLatch.countDown();
} else {
try {
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
// ignore
}
}
});

// Wait until the first task is already running in the thread
startedLatch.await();
}

// Wait until the first task is already running in the thread
startedLatch.await();

// Next task should go through, because the runner thread has already pulled out 1 item from the
// queue: the first tasks which is currently stuck
ste.execute(() -> {
Expand All @@ -116,6 +123,52 @@ public void testRejectWhenQueueIsFull() throws Exception {
assertEquals(0, ste.getFailedTasksCount());
}

@Test
public void testRejectWhenDrainToInProgressAndQueueIsEmpty() throws Exception {
@Cleanup("shutdownNow")
SingleThreadExecutor ste = new SingleThreadExecutor(THREAD_FACTORY, 10, true);

CountDownLatch waitedLatch = new CountDownLatch(1);
List<Runnable> tasks = new ArrayList<>();

for (int i = 0; i < 10; i++) {
tasks.add(() -> {
try {
// Block task to simulate an active, long-running task.
waitedLatch.await();
} catch (Exception e) {
// ignored
}
});
}
ste.executeRunnableOrList(null, tasks);

Awaitility.await().pollDelay(1, TimeUnit.SECONDS)
.untilAsserted(() -> assertEquals(10, ste.getPendingTaskCount()));

// Now the queue is really full and should reject tasks.
assertThrows(RejectedExecutionException.class, () -> ste.execute(() -> {
}));

assertEquals(10, ste.getPendingTaskCount());
assertEquals(1, ste.getRejectedTasksCount());
assertEquals(0, ste.getFailedTasksCount());

// Now we can unblock the waited tasks.
waitedLatch.countDown();

// Check the tasks are completed.
Awaitility.await().pollDelay(1, TimeUnit.SECONDS)
.untilAsserted(() -> assertEquals(0, ste.getPendingTaskCount()));

// Invalid cases - should throw IllegalArgumentException.
assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(null, null));
assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(null, Collections.emptyList()));
assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(() -> {
}, Lists.newArrayList(() -> {
})));
}

@Test
public void testBlockWhenQueueIsFull() throws Exception {
@Cleanup("shutdown")
Expand Down