|
10 | 10 |
|
11 | 11 | package org.junit.jupiter.engine.execution; |
12 | 12 |
|
| 13 | +import static java.util.stream.Collectors.toList; |
| 14 | +import static org.assertj.core.api.Assertions.assertThat; |
13 | 15 | import static org.junit.jupiter.api.Assertions.assertEquals; |
14 | 16 | import static org.junit.jupiter.api.Assertions.assertNotEquals; |
15 | 17 | import static org.junit.jupiter.api.Assertions.assertNull; |
16 | 18 | import static org.junit.jupiter.api.Assertions.assertThrows; |
17 | 19 |
|
| 20 | +import java.util.ArrayList; |
| 21 | +import java.util.List; |
| 22 | +import java.util.concurrent.CompletableFuture; |
| 23 | +import java.util.concurrent.CountDownLatch; |
| 24 | +import java.util.concurrent.ExecutorService; |
| 25 | +import java.util.concurrent.Executors; |
| 26 | +import java.util.concurrent.atomic.AtomicInteger; |
18 | 27 | import java.util.function.Function; |
| 28 | +import java.util.function.Supplier; |
19 | 29 |
|
20 | 30 | import org.junit.jupiter.api.Nested; |
21 | | -import org.junit.jupiter.api.RepeatedTest; |
22 | 31 | import org.junit.jupiter.api.Test; |
23 | 32 | import org.junit.jupiter.api.extension.ExtensionContext.Namespace; |
24 | 33 | import org.junit.jupiter.api.extension.ExtensionContextException; |
@@ -290,18 +299,41 @@ void removeNullValueWithTypeSafety() { |
290 | 299 | assertNull(store.get(namespace, key)); |
291 | 300 | } |
292 | 301 |
|
293 | | - @RepeatedTest(23) |
294 | | - void simulateRaceConditionInGetOrComputeIfAbsent() throws Exception { |
| 302 | + @Test |
| 303 | + void simulateRaceConditionInGetOrComputeIfAbsent() { |
| 304 | + int threads = 10; |
| 305 | + AtomicInteger counter = new AtomicInteger(); |
295 | 306 | ExtensionValuesStore localStore = new ExtensionValuesStore(null); |
296 | | - Thread t1 = new Thread(() -> localStore.getOrComputeIfAbsent(namespace, key, key -> value)); |
297 | | - Thread t2 = new Thread(() -> localStore.getOrComputeIfAbsent(namespace, key, key -> value)); |
298 | | - t1.start(); |
299 | | - t2.start(); |
300 | | - Thread.yield(); |
301 | | - localStore.getOrComputeIfAbsent(namespace, key, key -> value); // use current thread as well |
302 | | - t1.join(); |
303 | | - t2.join(); |
304 | | - assertEquals(value, localStore.get(namespace, key)); |
| 307 | + |
| 308 | + List<Object> values = executeConcurrently(threads, // |
| 309 | + () -> localStore.getOrComputeIfAbsent(namespace, key, it -> counter.incrementAndGet())); |
| 310 | + |
| 311 | + assertEquals(1, counter.get()); |
| 312 | + assertThat(values).hasSize(threads).containsOnly(1); |
| 313 | + } |
| 314 | + } |
| 315 | + |
| 316 | + private <T> List<T> executeConcurrently(int threads, Supplier<T> supplier) { |
| 317 | + ExecutorService executorService = Executors.newFixedThreadPool(threads); |
| 318 | + try { |
| 319 | + CountDownLatch latch = new CountDownLatch(threads); |
| 320 | + List<CompletableFuture<T>> futures = new ArrayList<>(); |
| 321 | + for (int i = 0; i < threads; i++) { |
| 322 | + futures.add(CompletableFuture.supplyAsync(() -> { |
| 323 | + latch.countDown(); |
| 324 | + try { |
| 325 | + latch.await(); |
| 326 | + } |
| 327 | + catch (InterruptedException e) { |
| 328 | + Thread.currentThread().interrupt(); |
| 329 | + } |
| 330 | + return supplier.get(); |
| 331 | + }, executorService)); |
| 332 | + } |
| 333 | + return futures.stream().map(CompletableFuture::join).collect(toList()); |
| 334 | + } |
| 335 | + finally { |
| 336 | + executorService.shutdown(); |
305 | 337 | } |
306 | 338 | } |
307 | 339 |
|
|
0 commit comments