Skip to content

Commit 790a10d

Browse files
authored
perf: improve Nearby logic for entity ranges (#1729)
Companion for the PR in the enterprise repo.
1 parent de5f86f commit 790a10d

File tree

11 files changed

+316
-68
lines changed

11 files changed

+316
-68
lines changed

core/src/main/java/ai/timefold/solver/core/impl/heuristic/selector/entity/decorator/FilteringEntityValueRangeSelector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public void phaseStarted(AbstractPhaseScope<Solution_> phaseScope) {
6262
super.phaseStarted(phaseScope);
6363
this.entitiesSize = childEntitySelector.getEntityDescriptor().extractEntities(phaseScope.getWorkingSolution()).size();
6464
this.reachableValues = phaseScope.getScoreDirector().getValueRangeManager()
65-
.getReachableValeMatrix(childEntitySelector.getEntityDescriptor().getGenuineListVariableDescriptor());
65+
.getReachableValues(phaseScope.getScoreDirector().getSolutionDescriptor().getListVariableDescriptor());
6666
this.childEntitySelector.phaseStarted(phaseScope);
6767
}
6868

core/src/main/java/ai/timefold/solver/core/impl/heuristic/selector/value/decorator/FilteringValueRangeSelector.java

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void phaseStarted(AbstractPhaseScope<Solution_> phaseScope) {
7676
this.nonReplayingValueSelector.phaseStarted(phaseScope);
7777
this.replayingValueSelector.phaseStarted(phaseScope);
7878
this.reachableValues = phaseScope.getScoreDirector().getValueRangeManager()
79-
.getReachableValeMatrix(listVariableStateSupply.getSourceVariableDescriptor());
79+
.getReachableValues(listVariableStateSupply.getSourceVariableDescriptor());
8080
valuesSize = reachableValues.getSize();
8181
}
8282

@@ -129,20 +129,31 @@ public Iterator<Object> iterator(Object entity) {
129129
@Override
130130
public Iterator<Object> iterator() {
131131
if (randomSelection) {
132-
return new RandomFilteringValueRangeIterator(replayingValueSelector.iterator(), listVariableStateSupply,
133-
reachableValues, workingRandom, (int) getSize(), checkSourceAndDestination, true);
132+
// If the nonReplayingValueSelector does not have any additional configuration,
133+
// we can bypass it and only use reachable values,
134+
// which helps optimize the number of evaluations.
135+
// However, if the nonReplayingValueSelector includes custom configurations,
136+
// such as filtering,
137+
// we will first evaluate its values and then filter out those that are not reachable.
138+
if (nonReplayingValueSelector instanceof IterableFromEntityPropertyValueSelector<Solution_>) {
139+
return new OptimizedRandomFilteringValueRangeIterator(replayingValueSelector.iterator(),
140+
listVariableStateSupply,
141+
reachableValues, workingRandom, (int) getSize(), checkSourceAndDestination);
142+
} else {
143+
return new RandomFilteringValueRangeIterator(replayingValueSelector.iterator(),
144+
nonReplayingValueSelector.iterator(), listVariableStateSupply, reachableValues, (int) getSize(),
145+
checkSourceAndDestination);
146+
}
134147
} else {
135148
return new OriginalFilteringValueRangeIterator(replayingValueSelector.iterator(),
136-
nonReplayingValueSelector.iterator(), listVariableStateSupply, reachableValues, checkSourceAndDestination,
137-
false);
149+
nonReplayingValueSelector.iterator(), listVariableStateSupply, reachableValues, checkSourceAndDestination);
138150
}
139151
}
140152

141153
@Override
142154
public Iterator<Object> endingIterator(Object entity) {
143155
return new OriginalFilteringValueRangeIterator(replayingValueSelector.iterator(),
144-
nonReplayingValueSelector.iterator(), listVariableStateSupply, reachableValues, checkSourceAndDestination,
145-
false);
156+
nonReplayingValueSelector.iterator(), listVariableStateSupply, reachableValues, checkSourceAndDestination);
146157
}
147158

148159
@Override
@@ -255,25 +266,21 @@ boolean isValueOrEntityReachable(Object destinationValue) {
255266
}
256267
}
257268

258-
private class OriginalFilteringValueRangeIterator extends AbstractFilteringValueRangeIterator {
269+
private abstract class AbstractUpcomingValueRangeIterator extends AbstractFilteringValueRangeIterator {
259270
// The value iterator that only replays the current selected value
260-
private final Iterator<Object> replayingValueIterator;
261-
// The value iterator returns all possible values based on its settings.
262-
// However,
263-
// it may include invalid values that need to be filtered out.
264-
// This iterator must be used to ensure that all positions are included in the CH phase.
265-
// This does not apply to the LS phase.
266-
private final Iterator<Object> valueIterator;
271+
final Iterator<Object> replayingValueIterator;
272+
// The value iterator returns all possible values based on the outer selector settings.
273+
final Iterator<Object> valueIterator;
267274

268-
private OriginalFilteringValueRangeIterator(Iterator<Object> replayingValueIterator, Iterator<Object> valueIterator,
275+
private AbstractUpcomingValueRangeIterator(Iterator<Object> replayingValueIterator, Iterator<Object> valueIterator,
269276
ListVariableStateSupply<Solution_> listVariableStateSupply, ReachableValues reachableValues,
270277
boolean checkSourceAndDestination, boolean useValueList) {
271278
super(listVariableStateSupply, reachableValues, checkSourceAndDestination, useValueList);
272279
this.replayingValueIterator = replayingValueIterator;
273280
this.valueIterator = valueIterator;
274281
}
275282

276-
private void initialize() {
283+
void initialize() {
277284
if (initialized) {
278285
return;
279286
}
@@ -288,6 +295,16 @@ private void initialize() {
288295
noData();
289296
}
290297
}
298+
}
299+
300+
private class OriginalFilteringValueRangeIterator extends AbstractUpcomingValueRangeIterator {
301+
302+
private OriginalFilteringValueRangeIterator(Iterator<Object> replayingValueIterator, Iterator<Object> valueIterator,
303+
ListVariableStateSupply<Solution_> listVariableStateSupply, ReachableValues reachableValues,
304+
boolean checkSourceAndDestination) {
305+
super(replayingValueIterator, valueIterator, listVariableStateSupply, reachableValues, checkSourceAndDestination,
306+
false);
307+
}
291308

292309
@Override
293310
protected Object createUpcomingSelection() {
@@ -306,16 +323,51 @@ protected Object createUpcomingSelection() {
306323
}
307324
}
308325

309-
private class RandomFilteringValueRangeIterator extends AbstractFilteringValueRangeIterator {
326+
private class RandomFilteringValueRangeIterator extends AbstractUpcomingValueRangeIterator {
327+
private final int maxBailoutSize;
328+
329+
private RandomFilteringValueRangeIterator(Iterator<Object> replayingValueIterator, Iterator<Object> valueIterator,
330+
ListVariableStateSupply<Solution_> listVariableStateSupply, ReachableValues reachableValues,
331+
int maxBailoutSize, boolean checkSourceAndDestination) {
332+
super(replayingValueIterator, valueIterator, listVariableStateSupply, reachableValues, checkSourceAndDestination,
333+
false);
334+
this.maxBailoutSize = maxBailoutSize;
335+
}
336+
337+
@Override
338+
protected Object createUpcomingSelection() {
339+
initialize();
340+
if (!hasData) {
341+
return noUpcomingSelection();
342+
}
343+
Object next;
344+
var bailoutSize = maxBailoutSize;
345+
do {
346+
if (bailoutSize <= 0 || !valueIterator.hasNext()) {
347+
return noUpcomingSelection();
348+
}
349+
bailoutSize--;
350+
next = valueIterator.next();
351+
} while (!isValueOrEntityReachable(next));
352+
return next;
353+
}
354+
}
355+
356+
/**
357+
* The optimized iterator only traverses reachable values from the current selection.
358+
* Unlike {@link RandomFilteringValueRangeIterator},
359+
* it does not use an outer iterator to filter out non-reachable values.
360+
*/
361+
private class OptimizedRandomFilteringValueRangeIterator extends AbstractFilteringValueRangeIterator {
310362

311363
private final Iterator<Object> replayingValueIterator;
312364
private final Random workingRandom;
313365
private final int maxBailoutSize;
314366

315-
private RandomFilteringValueRangeIterator(Iterator<Object> replayingValueIterator,
367+
private OptimizedRandomFilteringValueRangeIterator(Iterator<Object> replayingValueIterator,
316368
ListVariableStateSupply<Solution_> listVariableStateSupply, ReachableValues reachableValues,
317-
Random workingRandom, int maxBailoutSize, boolean checkSourceAndDestination, boolean useValueList) {
318-
super(listVariableStateSupply, reachableValues, checkSourceAndDestination, useValueList);
369+
Random workingRandom, int maxBailoutSize, boolean checkSourceAndDestination) {
370+
super(listVariableStateSupply, reachableValues, checkSourceAndDestination, true);
319371
this.replayingValueIterator = replayingValueIterator;
320372
this.workingRandom = workingRandom;
321373
this.maxBailoutSize = maxBailoutSize;

core/src/main/java/ai/timefold/solver/core/impl/score/director/ValueRangeManager.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ public final class ValueRangeManager<Solution_> {
5959
private final Map<ValueRangeDescriptor<Solution_>, CountableValueRange<?>> fromSolutionMap = new IdentityHashMap<>();
6060
private final Map<Object, Map<ValueRangeDescriptor<Solution_>, CountableValueRange<?>>> fromEntityMap =
6161
new IdentityHashMap<>();
62-
6362
private @Nullable ReachableValues reachableValues = null;
63+
6464
private @Nullable Solution_ cachedWorkingSolution = null;
6565
private @Nullable SolutionInitializationStatistics cachedInitializationStatistics = null;
6666
private @Nullable ProblemSizeStatistics cachedProblemSizeStatistics = null;
@@ -414,32 +414,30 @@ public long countOnEntity(ValueRangeDescriptor<Solution_> valueRangeDescriptor,
414414
.getSize();
415415
}
416416

417-
public ReachableValues getReachableValeMatrix(ListVariableDescriptor<Solution_> listVariableDescriptor) {
417+
public ReachableValues getReachableValues(ListVariableDescriptor<Solution_> listVariableDescriptor) {
418418
if (reachableValues == null) {
419419
if (cachedWorkingSolution == null) {
420420
throw new IllegalStateException(
421-
"Impossible state: the matrix %s requested before the working solution is known."
422-
.formatted(ReachableValues.class.getSimpleName()));
421+
"Impossible state: value reachability requested before the working solution is known.");
423422
}
424423
var entityDescriptor = listVariableDescriptor.getEntityDescriptor();
425-
var valueRangeDescriptor = listVariableDescriptor.getValueRangeDescriptor();
426424
var entityList = entityDescriptor.extractEntities(cachedWorkingSolution);
427-
var allValues = getFromSolution(valueRangeDescriptor);
425+
var allValues = getFromSolution(listVariableDescriptor.getValueRangeDescriptor());
428426
var valuesSize = allValues.getSize();
429427
if (valuesSize > Integer.MAX_VALUE) {
430428
throw new IllegalStateException(
431429
"The matrix %s cannot be built for the entity %s (%s) because value range has a size (%d) which is higher than Integer.MAX_VALUE."
432430
.formatted(ReachableValues.class.getSimpleName(),
433431
entityDescriptor.getEntityClass().getSimpleName(),
434-
valueRangeDescriptor.getVariableDescriptor().getVariableName(), valuesSize));
432+
listVariableDescriptor.getVariableName(), valuesSize));
435433
}
436434
// list of entities reachable for a value
437435
var entityMatrix = new IdentityHashMap<Object, Set<Object>>((int) valuesSize);
438436
// list of values reachable for a value
439437
var valueMatrix = new IdentityHashMap<Object, Set<Object>>((int) valuesSize);
440438
for (var entity : entityList) {
441439
var valuesIterator = allValues.createOriginalIterator();
442-
var range = getFromEntity(valueRangeDescriptor, entity);
440+
var range = getFromEntity(listVariableDescriptor.getValueRangeDescriptor(), entity);
443441
while (valuesIterator.hasNext()) {
444442
var value = valuesIterator.next();
445443
if (range.contains(value)) {

core/src/test/java/ai/timefold/solver/core/impl/constructionheuristic/DefaultConstructionHeuristicPhaseTest.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ void solveWithEntityValueRangeBasicVariable() {
365365
@Test
366366
void solveWithEntityValueRangeListVariable() {
367367
var solverConfig = PlannerTestUtils
368-
.buildSolverConfig(TestdataListEntityProvidingSolution.class, TestdataListEntityProvidingEntity.class)
368+
.buildSolverConfig(TestdataListEntityProvidingSolution.class, TestdataListEntityProvidingEntity.class,
369+
TestdataListEntityProvidingValue.class)
369370
.withEasyScoreCalculatorClass(TestdataListEntityProvidingScoreCalculator.class)
370371
.withPhases(new ConstructionHeuristicPhaseConfig());
371372

@@ -381,8 +382,10 @@ void solveWithEntityValueRangeListVariable() {
381382
var bestSolution = PlannerTestUtils.solve(solverConfig, solution, true);
382383
assertThat(bestSolution).isNotNull();
383384
// Only one entity should provide the value list and assign the values.
384-
assertThat(bestSolution.getEntityList().get(0).getValueList()).hasSameElementsAs(List.of(value1, value2));
385-
assertThat(bestSolution.getEntityList().get(1).getValueList()).hasSameElementsAs(List.of(value3));
385+
assertThat(bestSolution.getEntityList().get(0).getValueList().stream().map(TestdataListEntityProvidingValue::getCode))
386+
.hasSameElementsAs(List.of("v1", "v2"));
387+
assertThat(bestSolution.getEntityList().get(1).getValueList().stream().map(TestdataListEntityProvidingValue::getCode))
388+
.hasSameElementsAs(List.of("v3"));
386389
}
387390

388391
@Test

core/src/test/java/ai/timefold/solver/core/impl/heuristic/selector/common/ReachableMatrixTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void testReachableEntities() {
3232
var solutionDescriptor = scoreDirector.getSolutionDescriptor();
3333
var entityDescriptor = solutionDescriptor.findEntityDescriptor(TestdataListEntityProvidingEntity.class);
3434
var reachableValues = scoreDirector.getValueRangeManager()
35-
.getReachableValeMatrix(entityDescriptor.getGenuineListVariableDescriptor());
35+
.getReachableValues(entityDescriptor.getGenuineListVariableDescriptor());
3636

3737
assertThat(reachableValues.extractEntities(v1)).containsExactlyInAnyOrder(a);
3838
assertThat(reachableValues.extractEntities(v2)).containsExactlyInAnyOrder(a, b);
@@ -60,7 +60,7 @@ void testReachableValues() {
6060
var solutionDescriptor = scoreDirector.getSolutionDescriptor();
6161
var entityDescriptor = solutionDescriptor.findEntityDescriptor(TestdataListEntityProvidingEntity.class);
6262
var reachableValues = scoreDirector.getValueRangeManager()
63-
.getReachableValeMatrix(entityDescriptor.getGenuineListVariableDescriptor());
63+
.getReachableValues(entityDescriptor.getGenuineListVariableDescriptor());
6464

6565
assertThat(reachableValues.extractValues(v1)).containsExactlyInAnyOrder(v2, v3);
6666
assertThat(reachableValues.extractValues(v2)).containsExactlyInAnyOrder(v1, v3);

0 commit comments

Comments
 (0)