Skip to content

Commit 39e7a91

Browse files
authored
fix: throw reject when SingleThreadExecutor drainTo in progress and queue is empty (#4488)
Fix Issue #4465 Motivaction In SingleThreadExecutor, the runner drains all tasks from the queue into localTasks. Although the queue is empty in memory at this point, the tasks are still pending execution in localTasks—so logically, the queue is still "full." Calling execute() during this phase should not enqueue a new runnable into the queue, as doing so would exceed the intended capacity. This can lead to increased memory usage and potential OutOfMemory (OOM) issues. Changes To address this, we introduce a variable to track the total number of pending runnables. This counter is used to control whether a new task should be added: The counter is incremented when a runnable is added to the queue. The counter is decremented when a runnable is actually executed. This ensures accurate tracking of pending tasks and prevents overfilling the logical task queue.
1 parent f42c915 commit 39e7a91

File tree

2 files changed

+148
-16
lines changed

2 files changed

+148
-16
lines changed

bookkeeper-common/src/main/java/org/apache/bookkeeper/common/util/SingleThreadExecutor.java

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.apache.bookkeeper.common.util;
2020

21+
import com.google.common.annotations.VisibleForTesting;
2122
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
2223
import java.util.ArrayList;
2324
import java.util.List;
@@ -29,6 +30,7 @@
2930
import java.util.concurrent.RejectedExecutionException;
3031
import java.util.concurrent.ThreadFactory;
3132
import java.util.concurrent.TimeUnit;
33+
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
3234
import java.util.concurrent.atomic.LongAdder;
3335
import lombok.SneakyThrows;
3436
import lombok.extern.slf4j.Slf4j;
@@ -54,6 +56,11 @@ public class SingleThreadExecutor extends AbstractExecutorService implements Exe
5456
private final LongAdder tasksRejected = new LongAdder();
5557
private final LongAdder tasksFailed = new LongAdder();
5658

59+
private final int maxQueueCapacity;
60+
private static final AtomicIntegerFieldUpdater<SingleThreadExecutor> pendingTaskCountUpdater =
61+
AtomicIntegerFieldUpdater.newUpdater(SingleThreadExecutor.class, "pendingTaskCount");
62+
private volatile int pendingTaskCount = 0;
63+
5764
enum State {
5865
Running,
5966
Shutdown,
@@ -80,6 +87,8 @@ public SingleThreadExecutor(ThreadFactory tf, int maxQueueCapacity, boolean reje
8087
} else {
8188
this.queue = new GrowableMpScArrayConsumerBlockingQueue<>();
8289
}
90+
this.maxQueueCapacity = maxQueueCapacity;
91+
8392
this.runner = tf.newThread(this);
8493
this.state = State.Running;
8594
this.rejectExecution = rejectExecution;
@@ -144,6 +153,8 @@ private boolean safeRunTask(Runnable r) {
144153
tasksFailed.increment();
145154
log.error("Error while running task: {}", t.getMessage(), t);
146155
}
156+
} finally {
157+
decrementPendingTaskCount(1);
147158
}
148159

149160
return true;
@@ -162,7 +173,8 @@ public List<Runnable> shutdownNow() {
162173
this.state = State.Shutdown;
163174
this.runner.interrupt();
164175
List<Runnable> remainingTasks = new ArrayList<>();
165-
queue.drainTo(remainingTasks);
176+
int n = queue.drainTo(remainingTasks);
177+
decrementPendingTaskCount(n);
166178
return remainingTasks;
167179
}
168180

@@ -204,27 +216,89 @@ public long getFailedTasksCount() {
204216

205217
@Override
206218
public void execute(Runnable r) {
219+
executeRunnableOrList(r, null);
220+
}
221+
222+
@VisibleForTesting
223+
void executeRunnableOrList(Runnable runnable, List<Runnable> runnableList) {
207224
if (state != State.Running) {
208225
throw new RejectedExecutionException("Executor is shutting down");
209226
}
210227

228+
boolean hasSingle = runnable != null;
229+
boolean hasList = runnableList != null && !runnableList.isEmpty();
230+
231+
if (hasSingle == hasList) {
232+
// Both are provided or both are missing
233+
throw new IllegalArgumentException("Provide either 'runnable' or a non-empty 'runnableList', not both.");
234+
}
235+
211236
try {
212237
if (!rejectExecution) {
213-
queue.put(r);
214-
tasksCount.increment();
215-
} else {
216-
if (queue.offer(r)) {
238+
if (hasSingle) {
239+
queue.put(runnable);
217240
tasksCount.increment();
218241
} else {
219-
tasksRejected.increment();
220-
throw new ExecutorRejectedException("Executor queue is full");
242+
for (Runnable task : runnableList) {
243+
queue.put(task);
244+
tasksCount.increment();
245+
}
246+
}
247+
} else {
248+
int count = runnable != null ? 1 : runnableList.size();
249+
incrementPendingTaskCount(count);
250+
boolean success = hasSingle
251+
? queue.offer(runnable)
252+
: queue.addAll(runnableList);
253+
if (success) {
254+
tasksCount.add(count);
255+
} else {
256+
decrementPendingTaskCount(count);
257+
reject();
221258
}
222259
}
223260
} catch (InterruptedException e) {
224261
throw new RejectedExecutionException("Executor thread was interrupted", e);
225262
}
226263
}
227264

265+
private void incrementPendingTaskCount(int count) {
266+
if (maxQueueCapacity <= 0) {
267+
return; // Unlimited capacity
268+
}
269+
270+
if (count < 0) {
271+
throw new IllegalArgumentException("Count must be non-negative");
272+
}
273+
274+
int oldPendingTaskCount = pendingTaskCountUpdater.getAndAccumulate(this, count,
275+
(curr, inc) -> (curr + inc > maxQueueCapacity) ? curr : curr + inc);
276+
277+
if (oldPendingTaskCount + count > maxQueueCapacity) {
278+
reject();
279+
}
280+
}
281+
282+
private void decrementPendingTaskCount(int count) {
283+
if (maxQueueCapacity <= 0) {
284+
return; // Unlimited capacity
285+
}
286+
287+
if (count < 0) {
288+
throw new IllegalArgumentException("Count must be non-negative");
289+
}
290+
291+
int currentPendingCount = pendingTaskCountUpdater.addAndGet(this, -count);
292+
if (log.isDebugEnabled()) {
293+
log.debug("Released {} task(s), current pending count: {}", count, currentPendingCount);
294+
}
295+
}
296+
297+
private void reject() {
298+
tasksRejected.increment();
299+
throw new ExecutorRejectedException("Executor queue is full");
300+
}
301+
228302
public void registerMetrics(StatsLogger statsLogger) {
229303
// Register gauges
230304
statsLogger.scopeLabel("thread", runner.getName())
@@ -289,6 +363,11 @@ public Number getSample() {
289363
});
290364
}
291365

366+
@VisibleForTesting
367+
int getPendingTaskCount() {
368+
return pendingTaskCountUpdater.get(this);
369+
}
370+
292371
private static class ExecutorRejectedException extends RejectedExecutionException {
293372

294373
private ExecutorRejectedException(String msg) {

bookkeeper-common/src/test/java/org/apache/bookkeeper/common/util/TestSingleThreadExecutor.java

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020

2121
import static org.junit.Assert.assertEquals;
2222
import static org.junit.Assert.assertFalse;
23+
import static org.junit.Assert.assertThrows;
2324
import static org.junit.Assert.assertTrue;
2425
import static org.junit.Assert.fail;
2526

27+
import com.google.common.collect.Lists;
2628
import io.netty.util.concurrent.DefaultThreadFactory;
29+
import java.util.ArrayList;
30+
import java.util.Collections;
2731
import java.util.List;
2832
import java.util.concurrent.BrokenBarrierException;
2933
import java.util.concurrent.CountDownLatch;
@@ -83,20 +87,23 @@ public void testRejectWhenQueueIsFull() throws Exception {
8387
CountDownLatch startedLatch = new CountDownLatch(1);
8488

8589
for (int i = 0; i < 10; i++) {
90+
int n = i;
8691
ste.execute(() -> {
87-
startedLatch.countDown();
88-
89-
try {
90-
barrier.await();
91-
} catch (InterruptedException | BrokenBarrierException e) {
92-
// ignore
92+
if (n == 0) {
93+
startedLatch.countDown();
94+
} else {
95+
try {
96+
barrier.await();
97+
} catch (InterruptedException | BrokenBarrierException e) {
98+
// ignore
99+
}
93100
}
94101
});
95-
96-
// Wait until the first task is already running in the thread
97-
startedLatch.await();
98102
}
99103

104+
// Wait until the first task is already running in the thread
105+
startedLatch.await();
106+
100107
// Next task should go through, because the runner thread has already pulled out 1 item from the
101108
// queue: the first tasks which is currently stuck
102109
ste.execute(() -> {
@@ -116,6 +123,52 @@ public void testRejectWhenQueueIsFull() throws Exception {
116123
assertEquals(0, ste.getFailedTasksCount());
117124
}
118125

126+
@Test
127+
public void testRejectWhenDrainToInProgressAndQueueIsEmpty() throws Exception {
128+
@Cleanup("shutdownNow")
129+
SingleThreadExecutor ste = new SingleThreadExecutor(THREAD_FACTORY, 10, true);
130+
131+
CountDownLatch waitedLatch = new CountDownLatch(1);
132+
List<Runnable> tasks = new ArrayList<>();
133+
134+
for (int i = 0; i < 10; i++) {
135+
tasks.add(() -> {
136+
try {
137+
// Block task to simulate an active, long-running task.
138+
waitedLatch.await();
139+
} catch (Exception e) {
140+
// ignored
141+
}
142+
});
143+
}
144+
ste.executeRunnableOrList(null, tasks);
145+
146+
Awaitility.await().pollDelay(1, TimeUnit.SECONDS)
147+
.untilAsserted(() -> assertEquals(10, ste.getPendingTaskCount()));
148+
149+
// Now the queue is really full and should reject tasks.
150+
assertThrows(RejectedExecutionException.class, () -> ste.execute(() -> {
151+
}));
152+
153+
assertEquals(10, ste.getPendingTaskCount());
154+
assertEquals(1, ste.getRejectedTasksCount());
155+
assertEquals(0, ste.getFailedTasksCount());
156+
157+
// Now we can unblock the waited tasks.
158+
waitedLatch.countDown();
159+
160+
// Check the tasks are completed.
161+
Awaitility.await().pollDelay(1, TimeUnit.SECONDS)
162+
.untilAsserted(() -> assertEquals(0, ste.getPendingTaskCount()));
163+
164+
// Invalid cases - should throw IllegalArgumentException.
165+
assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(null, null));
166+
assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(null, Collections.emptyList()));
167+
assertThrows(IllegalArgumentException.class, () -> ste.executeRunnableOrList(() -> {
168+
}, Lists.newArrayList(() -> {
169+
})));
170+
}
171+
119172
@Test
120173
public void testBlockWhenQueueIsFull() throws Exception {
121174
@Cleanup("shutdown")

0 commit comments

Comments
 (0)