Skip to content

Commit daae677

Browse files
authored
fix: entity filtering NPE (#1866)
1 parent afe99f5 commit daae677

File tree

2 files changed

+122
-5
lines changed

2 files changed

+122
-5
lines changed

core/src/main/java/ai/timefold/solver/core/impl/heuristic/selector/entity/decorator/FilteringEntityByEntitySelector.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ private Object pickNext() {
298298
this.selected = null;
299299
while (entityIterator.hasNext()) {
300300
var entity = entityIterator.next();
301-
if (isReachable(entity)) {
301+
if (entity != null && isReachable(entity)) {
302302
return entity;
303303
}
304304
}
@@ -364,7 +364,7 @@ protected Object createUpcomingSelection() {
364364
}
365365
while (entityIterator.hasNext()) {
366366
var otherEntity = entityIterator.next();
367-
if (isReachable(replayedEntity, otherEntity)) {
367+
if (otherEntity != null && isReachable(replayedEntity, otherEntity)) {
368368
return otherEntity;
369369
}
370370
}
@@ -378,7 +378,7 @@ protected Object createPreviousSelection() {
378378
}
379379
while (entityIterator.hasPrevious()) {
380380
var otherEntity = entityIterator.previous();
381-
if (isReachable(replayedEntity, otherEntity)) {
381+
if (otherEntity != null && isReachable(replayedEntity, otherEntity)) {
382382
return otherEntity;
383383
}
384384
}
@@ -442,10 +442,10 @@ public Object next() {
442442
bailoutSize--;
443443
// We expect the iterator to apply a random selection
444444
var next = entityIterator.next();
445-
if (isReachable(currentReplayedEntity, next)) {
445+
if (next != null && isReachable(currentReplayedEntity, next)) {
446446
return next;
447447
}
448-
} while (bailoutSize > 0);
448+
} while (bailoutSize > 0 && entityIterator.hasNext());
449449
// If no reachable entity is found, we return the currently selected entity,
450450
// which will result in a non-doable move
451451
return currentReplayedEntity;

core/src/test/java/ai/timefold/solver/core/impl/heuristic/selector/move/generic/SwapMoveSelectorTest.java

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package ai.timefold.solver.core.impl.heuristic.selector.move.generic;
22

3+
import static ai.timefold.solver.core.impl.heuristic.selector.SelectorTestUtils.mockEntitySelector;
34
import static ai.timefold.solver.core.impl.heuristic.selector.SelectorTestUtils.phaseStarted;
45
import static ai.timefold.solver.core.impl.heuristic.selector.SelectorTestUtils.solvingStarted;
56
import static ai.timefold.solver.core.testdomain.list.TestdataListUtils.getEntityDescriptor;
@@ -15,18 +16,23 @@
1516
import java.util.List;
1617
import java.util.Random;
1718

19+
import ai.timefold.solver.core.api.score.director.ScoreDirector;
1820
import ai.timefold.solver.core.config.heuristic.selector.common.SelectionCacheType;
1921
import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor;
2022
import ai.timefold.solver.core.impl.heuristic.selector.SelectorTestUtils;
23+
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionFilter;
2124
import ai.timefold.solver.core.impl.heuristic.selector.entity.EntitySelector;
2225
import ai.timefold.solver.core.impl.heuristic.selector.entity.FromSolutionEntitySelector;
2326
import ai.timefold.solver.core.impl.heuristic.selector.entity.decorator.FilteringEntityByEntitySelector;
27+
import ai.timefold.solver.core.impl.heuristic.selector.entity.decorator.FilteringEntitySelector;
28+
import ai.timefold.solver.core.impl.heuristic.selector.entity.mimic.ManualEntityMimicRecorder;
2429
import ai.timefold.solver.core.impl.heuristic.selector.entity.mimic.MimicRecordingEntitySelector;
2530
import ai.timefold.solver.core.impl.heuristic.selector.entity.mimic.MimicReplayingEntitySelector;
2631
import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope;
2732
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
2833
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
2934
import ai.timefold.solver.core.testdomain.TestdataEntity;
35+
import ai.timefold.solver.core.testdomain.TestdataObject;
3036
import ai.timefold.solver.core.testdomain.TestdataValue;
3137
import ai.timefold.solver.core.testdomain.valuerange.entityproviding.TestdataEntityProvidingEntity;
3238
import ai.timefold.solver.core.testdomain.valuerange.entityproviding.TestdataEntityProvidingSolution;
@@ -336,6 +342,103 @@ void emptyRightOriginalLeftUnequalsRight() {
336342
verifyPhaseLifecycle(rightEntitySelector, 1, 2, 5);
337343
}
338344

345+
@Test
346+
void originalEntitiesPinned() {
347+
var v1 = new TestdataValue("1");
348+
var v2 = new TestdataValue("2");
349+
var v3 = new TestdataValue("3");
350+
var v4 = new TestdataValue("4");
351+
var e1 = new TestdataAllowsUnassignedEntityProvidingEntity("A", List.of(v1, v4));
352+
var e2 = new TestdataAllowsUnassignedEntityProvidingEntity("B", List.of(v2, v3));
353+
var e3 = new TestdataAllowsUnassignedEntityProvidingEntity("C", List.of(v1, v4));
354+
var solution = new TestdataAllowsUnassignedEntityProvidingSolution("s1");
355+
solution.setEntityList(List.of(e1, e2, e3));
356+
357+
var scoreDirector = mockScoreDirector(TestdataAllowsUnassignedEntityProvidingSolution.buildSolutionDescriptor());
358+
scoreDirector.setWorkingSolution(solution);
359+
360+
var leftEntitySelector = new ManualEntityMimicRecorder<>(
361+
mockEntitySelector(TestdataAllowsUnassignedEntityProvidingEntity.buildEntityDescriptor(), e1, e2, e3));
362+
363+
var replayingEntitySelector = new MimicReplayingEntitySelector<>(leftEntitySelector);
364+
var filteringEntitySelector =
365+
FilteringEntitySelector.of(
366+
mockEntitySelector(TestdataAllowsUnassignedEntityProvidingEntity.buildEntityDescriptor(), e1, e2, e3),
367+
new EntityCodeFiltering<>(List.of("B", "C")));
368+
var rightEntitySelector =
369+
new FilteringEntityByEntitySelector<>(filteringEntitySelector, replayingEntitySelector, false);
370+
var solverScope = solvingStarted(rightEntitySelector, scoreDirector);
371+
phaseStarted(rightEntitySelector, solverScope);
372+
373+
// Regular iterator
374+
// The left selector chooses A, and the right selector returns no value
375+
leftEntitySelector.setRecordedEntity(e1);
376+
var iterator = rightEntitySelector.iterator();
377+
assertThat(iterator.hasNext()).isFalse();
378+
// The left selector chooses B, and the right selector returns A
379+
leftEntitySelector.setRecordedEntity(e2);
380+
iterator = rightEntitySelector.iterator();
381+
assertThat(iterator.hasNext()).isTrue();
382+
assertThat(iterator.next()).hasToString("A");
383+
// No more moves
384+
assertThat(iterator.hasNext()).isFalse();
385+
386+
// ListIterator
387+
// The left selector chooses A, and the right selector returns no value
388+
leftEntitySelector.setRecordedEntity(e1);
389+
var listIterator = rightEntitySelector.listIterator();
390+
assertThat(listIterator.hasNext()).isFalse();
391+
// B <-> A
392+
leftEntitySelector.setRecordedEntity(e2);
393+
listIterator = rightEntitySelector.listIterator();
394+
assertThat(listIterator.hasNext()).isTrue();
395+
assertThat(listIterator.next()).hasToString("A");
396+
assertThat(listIterator.hasNext()).isFalse();
397+
// Backward move
398+
assertThat(listIterator.hasPrevious()).isTrue();
399+
assertThat(listIterator.previous()).hasToString("A");
400+
}
401+
402+
@Test
403+
void randomEntitiesPinned() {
404+
var v1 = new TestdataValue("1");
405+
var v2 = new TestdataValue("2");
406+
var v3 = new TestdataValue("3");
407+
var v4 = new TestdataValue("4");
408+
var e1 = new TestdataAllowsUnassignedEntityProvidingEntity("A", List.of(v1, v4), v1);
409+
var e2 = new TestdataAllowsUnassignedEntityProvidingEntity("B", List.of(v2, v3), v2);
410+
var e3 = new TestdataAllowsUnassignedEntityProvidingEntity("C", List.of(v1, v4));
411+
var solution = new TestdataAllowsUnassignedEntityProvidingSolution("s1");
412+
solution.setEntityList(List.of(e1, e2, e3));
413+
414+
var scoreDirector = mockScoreDirector(TestdataAllowsUnassignedEntityProvidingSolution.buildSolutionDescriptor());
415+
scoreDirector.setWorkingSolution(solution);
416+
417+
var baseEntitySelector =
418+
new FromSolutionEntitySelector<>(getEntityDescriptor(scoreDirector), SelectionCacheType.JUST_IN_TIME, true);
419+
var leftEntitySelector = new ManualEntityMimicRecorder<>(baseEntitySelector);
420+
421+
var replayingEntitySelector = new MimicReplayingEntitySelector<>(leftEntitySelector);
422+
var filteringEntitySelector =
423+
FilteringEntitySelector.of(baseEntitySelector, new EntityCodeFiltering<>(List.of("B", "C")));
424+
var rightEntitySelector =
425+
new FilteringEntityByEntitySelector<>(filteringEntitySelector, replayingEntitySelector, true);
426+
var solverScope = solvingStarted(rightEntitySelector, scoreDirector,
427+
new TestRandom(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1));
428+
phaseStarted(rightEntitySelector, solverScope);
429+
430+
// Random iterator
431+
// The left selector chooses B,
432+
// and the right selector chooses A (not reachable), and then only B (excluded by the filter)
433+
leftEntitySelector.setRecordedEntity(e2);
434+
var iterator = rightEntitySelector.iterator();
435+
assertThat(iterator.hasNext()).isTrue();
436+
// Return the same as the left selector
437+
assertThat(iterator.next()).hasToString("B");
438+
// No more moves
439+
assertThat(iterator.hasNext()).isFalse();
440+
}
441+
339442
@Test
340443
void singleVarRandomSelectionWithEntityValueRange() {
341444
var v1 = new TestdataValue("1");
@@ -492,4 +595,18 @@ void noReachableEntities() {
492595
var swapMove = (SwapMove<TestdataEntityProvidingSolution>) iterator.next();
493596
assertThat(swapMove.getLeftEntity()).isSameAs(swapMove.getRightEntity());
494597
}
598+
599+
private static class EntityCodeFiltering<Solution_> implements SelectionFilter<Solution_, Object> {
600+
601+
private final List<String> excludedCodes;
602+
603+
public EntityCodeFiltering(List<String> excludedCodes) {
604+
this.excludedCodes = excludedCodes;
605+
}
606+
607+
@Override
608+
public boolean accept(ScoreDirector<Solution_> scoreDirector, Object selection) {
609+
return !excludedCodes.contains(((TestdataObject) selection).getCode());
610+
}
611+
}
495612
}

0 commit comments

Comments
 (0)