Skip to content

Commit 656fcf4

Browse files
committed
Use ConcurrentSkipListSet with absolute ordering to back work queue
1 parent 9b48074 commit 656fcf4

File tree

1 file changed

+25
-35
lines changed

1 file changed

+25
-35
lines changed

junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@
2323
import java.util.ArrayList;
2424
import java.util.Collection;
2525
import java.util.HashMap;
26+
import java.util.Iterator;
2627
import java.util.List;
2728
import java.util.Map;
28-
import java.util.Queue;
29+
import java.util.Set;
2930
import java.util.concurrent.Callable;
3031
import java.util.concurrent.CompletableFuture;
32+
import java.util.concurrent.ConcurrentSkipListSet;
3133
import java.util.concurrent.ExecutorService;
3234
import java.util.concurrent.Future;
33-
import java.util.concurrent.PriorityBlockingQueue;
3435
import java.util.concurrent.RejectedExecutionException;
3536
import java.util.concurrent.RejectedExecutionHandler;
3637
import java.util.concurrent.Semaphore;
@@ -225,17 +226,16 @@ void processQueueEntries(WorkerLease workerLease, BooleanSupplier doneCondition)
225226
LOGGER.trace(() -> "yielding resource lock");
226227
break;
227228
}
228-
var queueEntries = workQueue.peekAll();
229-
if (queueEntries.isEmpty()) {
229+
if (workQueue.isEmpty()) {
230230
LOGGER.trace(() -> "no queue entries available");
231231
break;
232232
}
233-
processQueueEntries(queueEntries);
233+
processQueueEntries();
234234
}
235235
}
236236

237-
private void processQueueEntries(List<WorkQueue.Entry> queueEntries) {
238-
var queueEntriesByResult = tryToStealWorkWithoutBlocking(queueEntries);
237+
private void processQueueEntries() {
238+
var queueEntriesByResult = tryToStealWorkWithoutBlocking(workQueue);
239239
var queueModified = queueEntriesByResult.containsKey(WorkStealResult.EXECUTED_BY_THIS_WORKER) //
240240
|| queueEntriesByResult.containsKey(WorkStealResult.EXECUTED_BY_DIFFERENT_WORKER);
241241
if (queueModified) {
@@ -288,7 +288,6 @@ void invokeAll(List<? extends TestTask> testTasks) {
288288
private List<WorkQueue.Entry> forkConcurrentChildren(List<? extends TestTask> children,
289289
Consumer<TestTask> isolatedTaskCollector, List<TestTask> sameThreadTasks) {
290290

291-
int index = 0;
292291
List<WorkQueue.Entry> queueEntries = new ArrayList<>(children.size());
293292
for (TestTask child : children) {
294293
if (requiresGlobalReadWriteLock(child)) {
@@ -298,7 +297,7 @@ else if (child.getExecutionMode() == SAME_THREAD) {
298297
sameThreadTasks.add(child);
299298
}
300299
else {
301-
queueEntries.add(WorkQueue.Entry.createWithIndex(child, index++));
300+
queueEntries.add(workQueue.createEntry(child));
302301
}
303302
}
304303

@@ -316,13 +315,11 @@ else if (child.getExecutionMode() == SAME_THREAD) {
316315
}
317316

318317
private Map<WorkStealResult, List<WorkQueue.Entry>> tryToStealWorkWithoutBlocking(
319-
List<WorkQueue.Entry> queueEntries) {
318+
Iterable<WorkQueue.Entry> queueEntries) {
320319

321320
Map<WorkStealResult, List<WorkQueue.Entry>> queueEntriesByResult = new HashMap<>(
322321
WorkStealResult.values().length);
323-
if (!queueEntries.isEmpty()) {
324-
tryToStealWork(queueEntries, BlockingMode.NON_BLOCKING, queueEntriesByResult);
325-
}
322+
tryToStealWork(queueEntries, BlockingMode.NON_BLOCKING, queueEntriesByResult);
326323
return queueEntriesByResult;
327324
}
328325

@@ -334,7 +331,7 @@ private void tryToStealWorkWithBlocking(Map<WorkStealResult, List<WorkQueue.Entr
334331
tryToStealWork(entriesRequiringResourceLocks, BlockingMode.BLOCKING, queueEntriesByResult);
335332
}
336333

337-
private void tryToStealWork(List<WorkQueue.Entry> entries, BlockingMode blocking,
334+
private void tryToStealWork(Iterable<WorkQueue.Entry> entries, BlockingMode blocking,
338335
Map<WorkStealResult, List<WorkQueue.Entry>> queueEntriesByResult) {
339336
for (var entry : entries) {
340337
var state = tryToStealWork(entry, blocking);
@@ -540,16 +537,21 @@ private enum BlockingMode {
540537
NON_BLOCKING, BLOCKING
541538
}
542539

543-
private static class WorkQueue {
544-
545-
private final Queue<Entry> queue = new PriorityBlockingQueue<>();
540+
private static class WorkQueue implements Iterable<WorkQueue.Entry> {
541+
private final AtomicInteger index = new AtomicInteger();
542+
private final Set<Entry> queue = new ConcurrentSkipListSet<>();
546543

547544
Entry add(TestTask task) {
548-
Entry entry = Entry.create(task);
545+
Entry entry = createEntry(task);
549546
LOGGER.trace(() -> "forking: " + entry.task);
550547
return doAdd(entry);
551548
}
552549

550+
Entry createEntry(TestTask task) {
551+
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
552+
return new Entry(task, new CompletableFuture<>(), level, index.getAndIncrement());
553+
}
554+
553555
void addAll(Collection<Entry> entries) {
554556
entries.forEach(this::doAdd);
555557
}
@@ -567,13 +569,6 @@ private Entry doAdd(Entry entry) {
567569
return entry;
568570
}
569571

570-
private List<WorkQueue.Entry> peekAll() {
571-
List<Entry> entries = new ArrayList<>(queue);
572-
// Iteration order isn't the same as queue order.
573-
entries.sort(naturalOrder());
574-
return entries;
575-
}
576-
577572
boolean remove(Entry entry) {
578573
return queue.remove(entry);
579574
}
@@ -582,19 +577,14 @@ boolean isEmpty() {
582577
return queue.isEmpty();
583578
}
584579

580+
@Override
581+
public Iterator<Entry> iterator() {
582+
return queue.iterator();
583+
}
584+
585585
private record Entry(TestTask task, CompletableFuture<@Nullable Void> future, int level, int index)
586586
implements Comparable<Entry> {
587587

588-
static Entry create(TestTask task) {
589-
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
590-
return new Entry(task, new CompletableFuture<>(), level, 0);
591-
}
592-
593-
static Entry createWithIndex(TestTask task, int index) {
594-
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
595-
return new Entry(task, new CompletableFuture<>(), level, index);
596-
}
597-
598588
@SuppressWarnings("FutureReturnValueIgnored")
599589
Entry {
600590
future.whenComplete((__, t) -> {

0 commit comments

Comments
 (0)