Skip to content

Commit a1fb1fe

Browse files
authored
fix: improve thread safety around problem changes (#1439)
1 parent f8030ca commit a1fb1fe

File tree

7 files changed

+206
-47
lines changed

7 files changed

+206
-47
lines changed

build/build-parent/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
<version.org.apache.commons.math3>3.6.1</version.org.apache.commons.math3>
2626
<version.org.apache.logging.log4j>2.24.3</version.org.apache.logging.log4j>
2727
<version.org.assertj>3.27.3</version.org.assertj>
28+
<version.org.awaitility>4.3.0</version.org.awaitility>
2829
<version.org.freemarker>2.3.34</version.org.freemarker>
2930
<version.org.jspecify>1.0.0</version.org.jspecify>
3031
<version.org.openrewrite.recipe>3.3.0</version.org.openrewrite.recipe>
@@ -116,6 +117,12 @@
116117
<artifactId>assertj-core</artifactId>
117118
<version>${version.org.assertj}</version>
118119
</dependency>
120+
<dependency>
121+
<groupId>org.awaitility</groupId>
122+
<artifactId>awaitility</artifactId>
123+
<version>${version.org.awaitility}</version>
124+
<scope>test</scope>
125+
</dependency>
119126
<dependency>
120127
<groupId>org.freemarker</groupId>
121128
<artifactId>freemarker</artifactId>

core/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@
104104
<artifactId>assertj-core</artifactId>
105105
<scope>test</scope>
106106
</dependency>
107+
<dependency>
108+
<groupId>org.awaitility</groupId>
109+
<artifactId>awaitility</artifactId>
110+
<scope>test</scope>
111+
</dependency>
107112
<dependency>
108113
<groupId>org.mockito</groupId>
109114
<artifactId>mockito-core</artifactId>

core/src/main/java/ai/timefold/solver/core/api/solver/change/ProblemChange.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import ai.timefold.solver.core.api.solver.event.BestSolutionChangedEvent;
99
import ai.timefold.solver.core.impl.heuristic.move.Move;
1010

11-
import org.jspecify.annotations.NonNull;
11+
import org.jspecify.annotations.NullMarked;
1212

1313
/**
1414
* A ProblemChange represents a change in one or more {@link PlanningEntity planning entities} or problem facts
@@ -77,6 +77,7 @@
7777
* @param <Solution_> the solution type, the class with the {@link PlanningSolution} annotation
7878
*/
7979
@FunctionalInterface
80+
@NullMarked
8081
public interface ProblemChange<Solution_> {
8182

8283
/**
@@ -87,5 +88,5 @@ public interface ProblemChange<Solution_> {
8788
* (and {@link PlanningEntity planning entities}) to change
8889
* @param problemChangeDirector {@link ProblemChangeDirector} to perform the change through
8990
*/
90-
void doChange(@NonNull Solution_ workingSolution, @NonNull ProblemChangeDirector problemChangeDirector);
91+
void doChange(Solution_ workingSolution, ProblemChangeDirector problemChangeDirector);
9192
}

core/src/main/java/ai/timefold/solver/core/impl/solver/BestSolutionHolder.java

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.concurrent.CompletableFuture;
1111
import java.util.concurrent.atomic.AtomicReference;
1212
import java.util.function.BooleanSupplier;
13+
import java.util.function.UnaryOperator;
1314

1415
import ai.timefold.solver.core.api.solver.Solver;
1516
import ai.timefold.solver.core.api.solver.change.ProblemChange;
@@ -20,7 +21,7 @@
2021

2122
/**
2223
* The goal of this class is to register problem changes and best solutions in a thread-safe way.
23-
* Problem changes are {@link #addProblemChange(Solver, ProblemChange) put in a queue}
24+
* Problem changes are {@link #addProblemChange(Solver, List) put in a queue}
2425
* and later associated with the best solution which contains them.
2526
* The best solution is associated with a version number
2627
* that is incremented each time a {@link #set new best solution is set}.
@@ -29,26 +30,24 @@
2930
*
3031
* <p>
3132
* This class needs to be thread-safe.
32-
* Due to complicated interactions between the solver, solver manager and problem changes,
33-
* it is best if we avoid explicit locking here,
34-
* reducing cognitive complexity of the whole system.
35-
* The core idea being to never modify the same data structure from multiple threads;
36-
* instead, we replace the data structure with a new one atomically.
37-
* The code contains comments throughout the class that explain the reasoning behind the design.
3833
*
3934
* @param <Solution_>
4035
*/
4136
@NullMarked
4237
final class BestSolutionHolder<Solution_> {
4338

44-
private final AtomicReference<@Nullable VersionedBestSolution<Solution_>> versionedBestSolutionRef =
45-
new AtomicReference<>();
46-
private final AtomicReference<SortedMap<BigInteger, List<CompletableFuture<Void>>>> problemChangesPerVersionRef =
47-
new AtomicReference<>(createNewProblemChangesMap());
39+
private final AtomicReference<BigInteger> lastProcessedVersion = new AtomicReference<>(BigInteger.valueOf(-1));
40+
41+
// These references are non-final and being accessed from multiple threads,
42+
// therefore they need to be volatile and all access synchronized.
43+
// Both the map and the best solution are based on the current version,
44+
// and therefore access to both needs to be guarded by the same lock.
4845
// The version is BigInteger to avoid long overflow.
4946
// The solver can run potentially forever, so long overflow is a (remote) possibility.
50-
private final AtomicReference<BigInteger> currentVersion = new AtomicReference<>(BigInteger.ZERO);
51-
private final AtomicReference<BigInteger> lastProcessedVersion = new AtomicReference<>(BigInteger.valueOf(-1));
47+
private volatile SortedMap<BigInteger, List<CompletableFuture<Void>>> problemChangesPerVersionMap =
48+
createNewProblemChangesMap();
49+
private volatile @Nullable VersionedBestSolution<Solution_> versionedBestSolution = null;
50+
private volatile BigInteger currentVersion = BigInteger.ZERO;
5251

5352
private static SortedMap<BigInteger, List<CompletableFuture<Void>>> createNewProblemChangesMap() {
5453
return createNewProblemChangesMap(Collections.emptySortedMap());
@@ -59,8 +58,8 @@ private static SortedMap<BigInteger, List<CompletableFuture<Void>>> createNewPro
5958
return new TreeMap<>(map);
6059
}
6160

62-
boolean isEmpty() {
63-
return versionedBestSolutionRef.get() == null;
61+
synchronized boolean isEmpty() {
62+
return this.versionedBestSolution == null;
6463
}
6564

6665
/**
@@ -69,12 +68,12 @@ boolean isEmpty() {
6968
*/
7069
@Nullable
7170
BestSolutionContainingProblemChanges<Solution_> take() {
72-
var versionedBestSolution = versionedBestSolutionRef.getAndSet(null);
73-
if (versionedBestSolution == null) {
71+
var latestVersionedBestSolution = resetVersionedBestSolution();
72+
if (latestVersionedBestSolution == null) {
7473
return null;
7574
}
7675

77-
var bestSolutionVersion = versionedBestSolution.version();
76+
var bestSolutionVersion = latestVersionedBestSolution.version();
7877
var latestProcessedVersion = this.lastProcessedVersion.getAndUpdate(bestSolutionVersion::max);
7978
if (latestProcessedVersion.compareTo(bestSolutionVersion) > 0) {
8079
// Corner case: The best solution has already been taken,
@@ -84,7 +83,7 @@ BestSolutionContainingProblemChanges<Solution_> take() {
8483
return null;
8584
}
8685
// The map is replaced by a map containing only the problem changes that are not contained in the best solution.
87-
// This is done atomically, so no other thread can access the old map anymore.
86+
// This is fully synchronized, so no other thread can access the old map anymore.
8887
// The old map can then be processed by the current thread without synchronization.
8988
// The copying of maps is possibly expensive, but due to the nature of problem changes,
9089
// we do not expect the map to ever get too big.
@@ -93,7 +92,7 @@ BestSolutionContainingProblemChanges<Solution_> take() {
9392
// The solver also finds new best solutions, which regularly trims the size of the map as well.
9493
var boundaryVersion = bestSolutionVersion.add(BigInteger.ONE);
9594
var oldProblemChangesPerVersion =
96-
problemChangesPerVersionRef.getAndUpdate(map -> createNewProblemChangesMap(map.tailMap(boundaryVersion)));
95+
replaceMapSynchronized(map -> createNewProblemChangesMap(map.tailMap(boundaryVersion)));
9796
// At this point, the old map is not accessible to any other thread.
9897
// We also do not need to clear it, because this being the only reference,
9998
// garbage collector will do it for us.
@@ -102,29 +101,41 @@ BestSolutionContainingProblemChanges<Solution_> take() {
102101
.stream()
103102
.flatMap(Collection::stream)
104103
.toList();
105-
return new BestSolutionContainingProblemChanges<>(versionedBestSolution.bestSolution(), containedProblemChanges);
104+
return new BestSolutionContainingProblemChanges<>(latestVersionedBestSolution.bestSolution(), containedProblemChanges);
105+
}
106+
107+
private synchronized @Nullable VersionedBestSolution<Solution_> resetVersionedBestSolution() {
108+
var oldVersionedBestSolution = this.versionedBestSolution;
109+
this.versionedBestSolution = null;
110+
return oldVersionedBestSolution;
111+
}
112+
113+
private synchronized SortedMap<BigInteger, List<CompletableFuture<Void>>> replaceMapSynchronized(
114+
UnaryOperator<SortedMap<BigInteger, List<CompletableFuture<Void>>>> replaceFunction) {
115+
var oldMap = problemChangesPerVersionMap;
116+
problemChangesPerVersionMap = replaceFunction.apply(oldMap);
117+
return oldMap;
106118
}
107119

108120
/**
109-
* Sets the new best solution if all known problem changes have been processed and thus are contained in this
110-
* best solution.
121+
* Sets the new best solution if all known problem changes have been processed
122+
* and thus are contained in this best solution.
111123
*
112124
* @param bestSolution the new best solution that replaces the previous one if there is any
113125
* @param isEveryProblemChangeProcessed a supplier that tells if all problem changes have been processed
114126
*/
115127
void set(Solution_ bestSolution, BooleanSupplier isEveryProblemChangeProcessed) {
116-
/*
117-
* The new best solution can be accepted only if there are no pending problem changes
118-
* nor any additional changes may come during this operation.
119-
* Otherwise, a race condition might occur
120-
* that leads to associating problem changes with a solution that was created later,
121-
* but does not contain them yet.
122-
* As a result, CompletableFutures representing these changes would be completed too early.
123-
*/
128+
// The new best solution can be accepted only if there are no pending problem changes
129+
// nor any additional changes may come during this operation.
130+
// Otherwise, a race condition might occur
131+
// that leads to associating problem changes with a solution that was created later,
132+
// but does not contain them yet.
133+
// As a result, CompletableFutures representing these changes would be completed too early.
124134
if (isEveryProblemChangeProcessed.getAsBoolean()) {
125-
// This field is atomic, so we can safely set the new best solution without synchronization.
126-
versionedBestSolutionRef.set(
127-
new VersionedBestSolution<>(bestSolution, currentVersion.getAndUpdate(old -> old.add(BigInteger.ONE))));
135+
synchronized (this) {
136+
versionedBestSolution = new VersionedBestSolution<>(bestSolution, currentVersion);
137+
currentVersion = currentVersion.add(BigInteger.ONE);
138+
}
128139
}
129140
}
130141

@@ -139,23 +150,20 @@ void set(Solution_ bestSolution, BooleanSupplier isEveryProblemChangeProcessed)
139150
CompletableFuture<Void> addProblemChange(Solver<Solution_> solver, List<ProblemChange<Solution_>> problemChangeList) {
140151
var futureProblemChange = new CompletableFuture<Void>();
141152
synchronized (this) {
142-
// This actually needs to be synchronized,
143-
// as we want the new problem change and its version to be linked.
144-
var futureProblemChangeList =
145-
problemChangesPerVersionRef.get().computeIfAbsent(currentVersion.get(), version -> new ArrayList<>());
153+
var futureProblemChangeList = problemChangesPerVersionMap.computeIfAbsent(currentVersion,
154+
version -> new ArrayList<>());
146155
futureProblemChangeList.add(futureProblemChange);
147156
solver.addProblemChanges(problemChangeList);
148157
}
149158
return futureProblemChange;
150159
}
151160

152161
void cancelPendingChanges() {
153-
// The map is an atomic reference.
154-
// We first replace the reference with a new map atomically, avoiding synchronization issues.
155-
// Then we process the old map, which is safe because no one can access it anymore.
162+
// We first replace the reference with a new map, fully synchronized.
163+
// Then we process the old map unsynchronized, which is safe because no one can access it anymore.
156164
// We do not need to clear it, because this being the only reference,
157165
// the garbage collector will do it for us.
158-
problemChangesPerVersionRef.getAndSet(createNewProblemChangesMap())
166+
replaceMapSynchronized(map -> createNewProblemChangesMap())
159167
.values()
160168
.stream()
161169
.flatMap(Collection::stream)

core/src/test/java/ai/timefold/solver/core/impl/solver/BestSolutionHolderTest.java

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
11
package ai.timefold.solver.core.impl.solver;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.awaitility.Awaitility.await;
45
import static org.mockito.Mockito.mock;
56
import static org.mockito.Mockito.times;
67
import static org.mockito.Mockito.verify;
78

9+
import java.time.Duration;
10+
import java.util.ArrayList;
811
import java.util.List;
12+
import java.util.Random;
13+
import java.util.UUID;
914
import java.util.concurrent.CompletableFuture;
15+
import java.util.concurrent.CountDownLatch;
16+
import java.util.concurrent.Executors;
1017

1118
import ai.timefold.solver.core.api.solver.Solver;
19+
import ai.timefold.solver.core.api.solver.SolverManager;
1220
import ai.timefold.solver.core.api.solver.change.ProblemChange;
21+
import ai.timefold.solver.core.api.solver.change.ProblemChangeDirector;
22+
import ai.timefold.solver.core.config.solver.SolverConfig;
23+
import ai.timefold.solver.core.config.solver.SolverManagerConfig;
24+
import ai.timefold.solver.core.impl.testdata.domain.TestdataEasyScoreCalculator;
25+
import ai.timefold.solver.core.impl.testdata.domain.TestdataEntity;
1326
import ai.timefold.solver.core.impl.testdata.domain.TestdataSolution;
1427

28+
import org.jspecify.annotations.NullMarked;
29+
import org.junit.jupiter.api.RepeatedTest;
1530
import org.junit.jupiter.api.Test;
1631
import org.mockito.Mockito;
1732

@@ -85,4 +100,125 @@ private CompletableFuture<Void> addProblemChange(BestSolutionHolder<TestdataSolu
85100
Mockito.argThat(problemChanges -> problemChanges.size() == 1 && problemChanges.get(0) == problemChange));
86101
return futureChange;
87102
}
103+
104+
@RepeatedTest(value = 10, failureThreshold = 1) // Run it multiple times to increase the chance of catching a concurrency issue.
105+
void problemChangeBarrageIntermediateBestSolutionConsumer() throws InterruptedException {
106+
var solverConfig = new SolverConfig()
107+
.withSolutionClass(TestdataSolution.class)
108+
.withEntityClasses(TestdataEntity.class)
109+
.withEasyScoreCalculatorClass(TestdataEasyScoreCalculator.class);
110+
111+
var futureList = new ArrayList<RecordedFuture>();
112+
var executorService = Executors.newFixedThreadPool(2);
113+
try (var solverManager = SolverManager.<TestdataSolution, UUID> create(solverConfig, new SolverManagerConfig())) {
114+
var solverStartedLatch = new CountDownLatch(1);
115+
var solution = TestdataSolution.generateSolution();
116+
var solverJob = solverManager.solveBuilder()
117+
.withProblemId(UUID.randomUUID())
118+
.withProblem(solution)
119+
.withFirstInitializedSolutionConsumer((testdataSolution, isTerminatedEarly) -> {
120+
solverStartedLatch.countDown();
121+
})
122+
.withBestSolutionConsumer(testdataSolution -> {
123+
// No need to do anything.
124+
})
125+
.run();
126+
solverStartedLatch.await(); // Only start adding problem changes after CH finished.
127+
128+
var random = new Random(0);
129+
var problemChangeCount = 200; // Arbitrary, for a reasonable test duration.
130+
var problemChangesAddedLatch = new CountDownLatch(problemChangeCount);
131+
for (int i = 0; i < problemChangeCount; i++) {
132+
// Emulate a random delay between problem changes, as it would happen in real world.
133+
var randomDelayNanos = random.nextInt(1_000_000);
134+
var start = System.nanoTime();
135+
while ((System.nanoTime() - randomDelayNanos) < start) {
136+
Thread.onSpinWait();
137+
}
138+
// Submit the problem change and store the future.
139+
var problemChange = random.nextBoolean()
140+
? new EntityAddingProblemChange(problemChangesAddedLatch)
141+
: new EntityRemovingProblemChange(problemChangesAddedLatch);
142+
futureList.add(new RecordedFuture(i, solverJob.addProblemChange(problemChange)));
143+
}
144+
// All problem changes have been added.
145+
// Does not guarantee all have been processed though.
146+
problemChangesAddedLatch.await();
147+
148+
// A best solution should have been produced for all the processed changes.
149+
// Any incomplete futures here means some problem change was "lost".
150+
var lostFutureList = futureList.stream()
151+
.filter(future -> !future.isDone())
152+
.toList();
153+
var lostFutureCount = lostFutureList.size();
154+
if (lostFutureCount == 0) {
155+
return;
156+
}
157+
// The only exception to the rule:
158+
// the very last problem changes, which might not have been processed yet
159+
// by the time the solver was forced to terminate.
160+
var minIncompleteFutureId = lostFutureList.stream()
161+
.mapToInt(f -> f.id)
162+
.min()
163+
.orElseThrow(() -> new AssertionError("Impossible state: no incomplete future found."));
164+
assertThat(minIncompleteFutureId).isEqualTo(problemChangeCount - lostFutureCount);
165+
} finally {
166+
executorService.shutdownNow();
167+
// The solver is terminated.
168+
// All incomplete futures should have been canceled.
169+
var incompleteFutureList = futureList.stream()
170+
.filter(future -> {
171+
await().atMost(Duration.ofSeconds(1))
172+
.pollInterval(Duration.ofMillis(1))
173+
.until(future::isDone);
174+
return !future.isDone();
175+
})
176+
.toList();
177+
assertThat(incompleteFutureList)
178+
.as("All futures should have been completed by now.")
179+
.isEmpty();
180+
}
181+
182+
}
183+
184+
private record RecordedFuture(int id, CompletableFuture<Void> future) {
185+
186+
boolean isDone() {
187+
return future.isDone();
188+
}
189+
190+
}
191+
192+
@NullMarked
193+
private record EntityAddingProblemChange(CountDownLatch latch) implements ProblemChange<TestdataSolution> {
194+
195+
@Override
196+
public void doChange(TestdataSolution workingSolution, ProblemChangeDirector problemChangeDirector) {
197+
var entity = new TestdataEntity(UUID.randomUUID().toString());
198+
problemChangeDirector.addEntity(entity,
199+
e -> workingSolution.getEntityList().add(e));
200+
problemChangeDirector.updateShadowVariables();
201+
latch.countDown();
202+
}
203+
204+
}
205+
206+
@NullMarked
207+
private record EntityRemovingProblemChange(CountDownLatch latch) implements ProblemChange<TestdataSolution> {
208+
209+
@Override
210+
public void doChange(TestdataSolution workingSolution, ProblemChangeDirector problemChangeDirector) {
211+
if (workingSolution.getEntityList().size() < 2) {
212+
latch.countDown();
213+
return;
214+
}
215+
var entity = workingSolution.getEntityList().get(0);
216+
problemChangeDirector.removeEntity(entity,
217+
e -> workingSolution.getEntityList().remove(e));
218+
problemChangeDirector.updateShadowVariables();
219+
latch.countDown();
220+
}
221+
222+
}
223+
88224
}

0 commit comments

Comments
 (0)