Skip to content

Commit 1196b17

Browse files
committed
Hold back one task for the current worker when forking
1 parent a545da0 commit 1196b17

File tree

1 file changed

+38
-15
lines changed

1 file changed

+38
-15
lines changed

junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public class ConcurrentHierarchicalTestExecutorService implements HierarchicalTe
5656

5757
private final WorkQueue workQueue = new WorkQueue();
5858
private final ExecutorService threadPool;
59+
private final int parallelism;
5960
private final WorkerLeaseManager workerLeaseManager;
6061

6162
public ConcurrentHierarchicalTestExecutorService(ConfigurationParameters configurationParameters) {
@@ -70,7 +71,8 @@ public ConcurrentHierarchicalTestExecutorService(ParallelExecutionConfiguration
7071
ThreadFactory threadFactory = new WorkerThreadFactory(classLoader);
7172
threadPool = new ThreadPoolExecutor(configuration.getCorePoolSize(), configuration.getMaxPoolSize(),
7273
configuration.getKeepAliveSeconds(), SECONDS, new SynchronousQueue<>(), threadFactory);
73-
workerLeaseManager = new WorkerLeaseManager(configuration.getParallelism());
74+
parallelism = configuration.getParallelism();
75+
workerLeaseManager = new WorkerLeaseManager(parallelism);
7476
LOGGER.trace(() -> "initialized thread pool for parallelism of " + configuration.getParallelism());
7577
}
7678

@@ -110,6 +112,14 @@ private WorkQueue.Entry enqueue(TestTask testTask) {
110112
return entry;
111113
}
112114

115+
private void forkAll(Collection<WorkQueue.Entry> entries) {
116+
workQueue.addAll(entries);
117+
// start at most (parallelism - 1) new workers as this method is called from a worker thread holding a lease
118+
for (int i = 1; i < parallelism; i++) {
119+
maybeStartWorker();
120+
}
121+
}
122+
113123
private void maybeStartWorker() {
114124
if (threadPool.isShutdown() || !workerLeaseManager.isLeaseAvailable() || workQueue.isEmpty()) {
115125
return;
@@ -265,32 +275,35 @@ void invokeAll(List<? extends TestTask> testTasks) {
265275

266276
List<TestTask> isolatedTasks = new ArrayList<>(testTasks.size());
267277
List<TestTask> sameThreadTasks = new ArrayList<>(testTasks.size());
268-
var queueEntries = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks::add);
278+
var queueEntries = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
269279
executeAll(sameThreadTasks);
270280
var remainingForkedChildren = stealWork(queueEntries);
271281
waitFor(remainingForkedChildren);
272282
executeAll(isolatedTasks);
273283
}
274284

275285
private Collection<WorkQueue.Entry> forkConcurrentChildren(List<? extends TestTask> children,
276-
Consumer<TestTask> isolatedTaskCollector, Consumer<TestTask> sameThreadTaskCollector) {
277-
278-
if (children.isEmpty()) {
279-
return List.of();
280-
}
286+
Consumer<TestTask> isolatedTaskCollector, List<TestTask> sameThreadTasks) {
281287

282288
Queue<WorkQueue.Entry> queueEntries = new PriorityQueue<>(children.size(), reverseOrder());
283289
for (TestTask child : children) {
284290
if (requiresGlobalReadWriteLock(child)) {
285291
isolatedTaskCollector.accept(child);
286292
}
287293
else if (child.getExecutionMode() == SAME_THREAD) {
288-
sameThreadTaskCollector.accept(child);
294+
sameThreadTasks.add(child);
289295
}
290296
else {
291-
queueEntries.add(enqueue(child));
297+
queueEntries.add(WorkQueue.Entry.create(child));
292298
}
293299
}
300+
if (!queueEntries.isEmpty()) {
301+
if (sameThreadTasks.isEmpty()) {
302+
// hold back one task for this thread
303+
sameThreadTasks.add(queueEntries.poll().task);
304+
}
305+
forkAll(queueEntries);
306+
}
294307
return queueEntries;
295308
}
296309

@@ -313,7 +326,7 @@ private List<WorkQueue.Entry> stealWork(Collection<WorkQueue.Entry> queueEntries
313326
LOGGER.trace(() -> "stole work: " + entry);
314327
var executed = tryExecute(entry);
315328
if (!executed) {
316-
workQueue.add(entry);
329+
workQueue.reAdd(entry);
317330
concurrentlyExecutingChildren.add(entry);
318331
}
319332
}
@@ -354,7 +367,7 @@ private void executeAll(List<? extends TestTask> children) {
354367
if (children.isEmpty()) {
355368
return;
356369
}
357-
LOGGER.trace(() -> "running %d SAME_THREAD children".formatted(children.size()));
370+
LOGGER.trace(() -> "running %d children directly".formatted(children.size()));
358371
if (children.size() == 1) {
359372
executeTask(children.get(0));
360373
return;
@@ -426,12 +439,16 @@ private static class WorkQueue {
426439
private final Queue<Entry> queue = new PriorityBlockingQueue<>();
427440

428441
Entry add(TestTask task) {
429-
LOGGER.trace(() -> "forking: " + task);
430-
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
431-
return doAdd(new Entry(task, new CompletableFuture<>(), level, 0));
442+
Entry entry = Entry.create(task);
443+
LOGGER.trace(() -> "forking: " + entry.task);
444+
return doAdd(entry);
445+
}
446+
447+
void addAll(Collection<Entry> entries) {
448+
entries.forEach(this::doAdd);
432449
}
433450

434-
void add(Entry entry) {
451+
void reAdd(Entry entry) {
435452
LOGGER.trace(() -> "re-enqueuing: " + entry.task);
436453
doAdd(entry.incrementAttempts());
437454
}
@@ -459,6 +476,12 @@ boolean isEmpty() {
459476

460477
private record Entry(TestTask task, CompletableFuture<@Nullable Void> future, int level, int attempts)
461478
implements Comparable<Entry> {
479+
480+
static Entry create(TestTask task) {
481+
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
482+
return new Entry(task, new CompletableFuture<>(), level, 0);
483+
}
484+
462485
@SuppressWarnings("FutureReturnValueIgnored")
463486
Entry {
464487
future.whenComplete((__, t) -> {

0 commit comments

Comments
 (0)