Skip to content

Commit b607542

Browse files
committed
chore: move streams dataset automatically contains null if it needs to
1 parent 56340ed commit b607542

File tree

7 files changed

+128
-19
lines changed

7 files changed

+128
-19
lines changed

core/src/main/java/ai/timefold/solver/core/impl/move/streams/DefaultMoveStreamFactory.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import ai.timefold.solver.core.impl.move.streams.maybeapi.stream.UniMoveStream;
1313
import ai.timefold.solver.core.preview.api.domain.metamodel.GenuineVariableMetaModel;
1414
import ai.timefold.solver.core.preview.api.domain.metamodel.PlanningEntityMetaModel;
15+
import ai.timefold.solver.core.preview.api.domain.metamodel.PlanningVariableMetaModel;
1516

1617
import org.jspecify.annotations.NullMarked;
1718

@@ -43,6 +44,17 @@ public <A> UniDataStream<Solution_, A> enumerate(Class<A> clz) {
4344
return enumerate(entityMetaModel.type());
4445
}
4546

47+
public <Entity_> UniDataStream<Solution_, Entity_>
48+
enumerateEntities(PlanningVariableMetaModel<Solution_, Entity_, ?> variableMetaModel) {
49+
return enumerateEntities(variableMetaModel.entity());
50+
}
51+
52+
/**
53+
* Enumerate possible values for a given variable.
54+
* If the variable allows unassigned values, the resulting stream will include a null value.
55+
*
56+
* @return data stream with all possible values of a given variable
57+
*/
4658
public <Entity_, A> UniDataStream<Solution_, A>
4759
enumeratePossibleValues(GenuineVariableMetaModel<Solution_, Entity_, A> variableMetaModel) {
4860
var variableDescriptor = getVariableDescriptor(variableMetaModel);

core/src/main/java/ai/timefold/solver/core/impl/move/streams/dataset/AbstractUniDataStream.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,6 @@ private <B> UniDataStream<Solution_, A> ifExistsOrNot(boolean shouldExist, UniDa
7575
joinerComber.getMergedJoiner(), joinerComber.getMergedFiltering()), childStreamList::add);
7676
}
7777

78-
@Override
79-
public UniDataStream<Solution_, A> addNull() {
80-
throw new UnsupportedOperationException();
81-
}
82-
8378
public UniDataset<Solution_, A> createDataset() {
8479
var stream = shareAndAddChild(new TerminalUniDataStream<>(dataStreamFactory, this));
8580
return stream.getDataset();

core/src/main/java/ai/timefold/solver/core/impl/move/streams/generic/provider/ChangeMoveProvider.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ public MoveProducer<Solution_> apply(MoveStreamFactory<Solution_> moveStreamFact
3737
var defaultMoveStreamFactory = (DefaultMoveStreamFactory<Solution_>) moveStreamFactory;
3838
var valueStream = defaultMoveStreamFactory.enumeratePossibleValues(variableMetaModel)
3939
.filter(this::acceptValue);
40-
if (variableMetaModel.allowsUnassigned()) {
41-
valueStream = valueStream.addNull();
42-
}
43-
return moveStreamFactory.pick(defaultMoveStreamFactory.enumerateEntities(variableMetaModel.entity())
44-
.filter(this::acceptEntity))
40+
var entityStream = defaultMoveStreamFactory.enumerateEntities(variableMetaModel)
41+
.filter(this::acceptEntity);
42+
return moveStreamFactory.pick(entityStream)
4543
.pick(valueStream, this::acceptEntityValuePair)
4644
.asMove((solution, entity, value) -> new ChangeMove<>(variableMetaModel, entity, value));
4745
}

core/src/main/java/ai/timefold/solver/core/impl/move/streams/maybeapi/stream/UniDataStream.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,4 @@ default UniDataStream<Solution_, A> ifNotExistsOtherIncludingUnassigned(Class<A>
160160
return ifNotExistsIncludingUnassigned(otherClass, allJoiners);
161161
}
162162

163-
UniDataStream<Solution_, A> addNull();
164-
165163
}

core/src/test/java/ai/timefold/solver/core/impl/move/streams/dataset/UniDatasetStreamTest.java

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import static org.assertj.core.api.Assertions.assertThat;
44

5+
import ai.timefold.solver.core.impl.testdata.domain.TestdataEntity;
6+
import ai.timefold.solver.core.impl.testdata.domain.TestdataSolution;
57
import ai.timefold.solver.core.impl.testdata.domain.list.TestdataListEntity;
68
import ai.timefold.solver.core.impl.testdata.domain.list.TestdataListSolution;
79

@@ -10,7 +12,38 @@
1012
class UniDatasetStreamTest {
1113

1214
@Test
13-
void forEach() {
15+
void forEachBasicVariable() {
16+
var dataStreamFactory = new DataStreamFactory<>(TestdataSolution.buildSolutionDescriptor());
17+
var uniDataset = ((AbstractUniDataStream<TestdataSolution, TestdataEntity>) dataStreamFactory
18+
.forEach(TestdataEntity.class))
19+
.createDataset();
20+
21+
var solution = TestdataSolution.generateSolution(2, 2);
22+
var datasetSession = UniDatasetStreamTest.createSession(dataStreamFactory, solution);
23+
var uniDatasetInstance = datasetSession.getInstance(uniDataset);
24+
25+
var entity1 = solution.getEntityList().get(0);
26+
var entity2 = solution.getEntityList().get(1);
27+
28+
assertThat(uniDatasetInstance.iterator())
29+
.toIterable()
30+
.map(t -> t.factA)
31+
.containsExactly(entity1, entity2);
32+
33+
// Make incremental changes.
34+
var entity3 = new TestdataEntity("entity3", solution.getValueList().get(0));
35+
datasetSession.insert(entity3);
36+
datasetSession.retract(entity2);
37+
datasetSession.settle();
38+
39+
assertThat(uniDatasetInstance.iterator())
40+
.toIterable()
41+
.map(t -> t.factA)
42+
.containsExactly(entity1, entity3);
43+
}
44+
45+
@Test
46+
void forEachListVariable() {
1447
var dataStreamFactory = new DataStreamFactory<>(TestdataListSolution.buildSolutionDescriptor());
1548
var uniDataset = ((AbstractUniDataStream<TestdataListSolution, TestdataListEntity>) dataStreamFactory
1649
.forEach(TestdataListEntity.class))
@@ -40,8 +73,8 @@ void forEach() {
4073
.containsExactly(entity1, entity3);
4174
}
4275

43-
private DatasetSession<TestdataListSolution> createSession(DataStreamFactory<TestdataListSolution> dataStreamFactory,
44-
TestdataListSolution solution) {
76+
private static <Solution_> DatasetSession<Solution_> createSession(DataStreamFactory<Solution_> dataStreamFactory,
77+
Solution_ solution) {
4578
var datasetSessionFactory = new DatasetSessionFactory<>(dataStreamFactory);
4679
var datasetSession = datasetSessionFactory.buildSession();
4780
datasetSession.initialize(solution);

core/src/test/java/ai/timefold/solver/core/impl/move/streams/maybeapi/provider/ChangeMoveProviderTest.java

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
import ai.timefold.solver.core.impl.testdata.domain.TestdataEntity;
1414
import ai.timefold.solver.core.impl.testdata.domain.TestdataSolution;
1515
import ai.timefold.solver.core.impl.testdata.domain.TestdataValue;
16+
import ai.timefold.solver.core.impl.testdata.domain.allows_unassigned.TestdataAllowsUnassignedEntity;
17+
import ai.timefold.solver.core.impl.testdata.domain.allows_unassigned.TestdataAllowsUnassignedSolution;
1618

1719
import org.junit.jupiter.api.Test;
1820

1921
class ChangeMoveProviderTest {
2022

2123
@Test
22-
void fromSolution() {
24+
void fromSolutionBasicVariable() {
2325
var solutionDescriptor = TestdataSolution.buildSolutionDescriptor();
2426
var variableMetaModel = solutionDescriptor.getMetaModel()
2527
.entity(TestdataEntity.class)
@@ -79,8 +81,71 @@ void fromSolution() {
7981
});
8082
}
8183

82-
private MoveStreamSession<TestdataSolution> createSession(DefaultMoveStreamFactory<TestdataSolution> moveStreamFactory,
83-
SolutionDescriptor<TestdataSolution> solutionDescriptor, TestdataSolution solution) {
84+
@Test
85+
void fromSolutionBasicVariableAllowsUnassigned() {
86+
var solutionDescriptor = TestdataAllowsUnassignedSolution.buildSolutionDescriptor();
87+
var variableMetaModel = solutionDescriptor.getMetaModel()
88+
.entity(TestdataAllowsUnassignedEntity.class)
89+
.genuineVariable()
90+
.ensurePlanningVariable();
91+
var moveStreamFactory = new DefaultMoveStreamFactory<>(solutionDescriptor);
92+
var moveProvider = new ChangeMoveProvider<>(variableMetaModel);
93+
var moveProducer = moveProvider.apply(moveStreamFactory);
94+
95+
var solution = TestdataAllowsUnassignedSolution.generateSolution(2, 2);
96+
var firstEntity = solution.getEntityList().get(0); // Assigned to null.
97+
var secondEntity = solution.getEntityList().get(1); // Assigned to secondValue.
98+
var firstValue = solution.getValueList().get(0); // Not assigned to any entity.
99+
var secondValue = solution.getValueList().get(1);
100+
var moveStreamSession = createSession(moveStreamFactory, solutionDescriptor, solution);
101+
102+
// Filters out moves that would change the value to the value the entity already has.
103+
// Therefore this will have 4 moves (2 entities * 2 values) as opposed to 6 (2 entities * 3 values).
104+
var moveIterable = moveProducer.getMoveIterable(moveStreamSession);
105+
assertThat(moveIterable).hasSize(4);
106+
107+
var moveList = StreamSupport.stream(moveIterable.spliterator(), false)
108+
.map(m -> (ChangeMove<TestdataAllowsUnassignedSolution, TestdataAllowsUnassignedEntity, TestdataValue>) m)
109+
.toList();
110+
assertThat(moveList).hasSize(4);
111+
112+
// First entity is assigned to null, therefore the applicable moves assign to firstValue and secondValue.
113+
var firstMove = moveList.get(0);
114+
assertSoftly(softly -> {
115+
softly.assertThat(firstMove.extractPlanningEntities())
116+
.containsExactly(firstEntity);
117+
softly.assertThat(firstMove.extractPlanningValues())
118+
.containsExactly(firstValue);
119+
});
120+
121+
var secondMove = moveList.get(1);
122+
assertSoftly(softly -> {
123+
softly.assertThat(secondMove.extractPlanningEntities())
124+
.containsExactly(firstEntity);
125+
softly.assertThat(secondMove.extractPlanningValues())
126+
.containsExactly(secondValue);
127+
});
128+
129+
// Second entity is assigned to secondValue, therefore the applicable moves assign to null and firstValue.
130+
var thirdMove = moveList.get(2);
131+
assertSoftly(softly -> {
132+
softly.assertThat(thirdMove.extractPlanningEntities())
133+
.containsExactly(secondEntity);
134+
softly.assertThat(thirdMove.extractPlanningValues())
135+
.containsExactly(new TestdataValue[] { null });
136+
});
137+
138+
var fourthMove = moveList.get(3);
139+
assertSoftly(softly -> {
140+
softly.assertThat(fourthMove.extractPlanningEntities())
141+
.containsExactly(secondEntity);
142+
softly.assertThat(fourthMove.extractPlanningValues())
143+
.containsExactly(firstValue);
144+
});
145+
}
146+
147+
private <Solution_> MoveStreamSession<Solution_> createSession(DefaultMoveStreamFactory<Solution_> moveStreamFactory,
148+
SolutionDescriptor<Solution_> solutionDescriptor, Solution_ solution) {
84149
var moveStreamSession = moveStreamFactory.createSession(solution);
85150
solutionDescriptor.visitAll(solution, moveStreamSession::insert);
86151
moveStreamSession.settle();

core/src/test/java/ai/timefold/solver/core/impl/testdata/domain/TestdataSolution.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ public static TestdataSolution generateSolution() {
2323
}
2424

2525
public static TestdataSolution generateSolution(int valueListSize, int entityListSize) {
26+
return generateSolution(valueListSize, entityListSize, true);
27+
}
28+
29+
public static TestdataSolution generateUninitializedSolution(int valueListSize, int entityListSize) {
30+
return generateSolution(valueListSize, entityListSize, false);
31+
}
32+
33+
private static TestdataSolution generateSolution(int valueListSize, int entityListSize, boolean initialized) {
2634
TestdataSolution solution = new TestdataSolution("Generated Solution 0");
2735
List<TestdataValue> valueList = new ArrayList<>(valueListSize);
2836
for (int i = 0; i < valueListSize; i++) {
@@ -32,7 +40,7 @@ public static TestdataSolution generateSolution(int valueListSize, int entityLis
3240
solution.setValueList(valueList);
3341
List<TestdataEntity> entityList = new ArrayList<>(entityListSize);
3442
for (int i = 0; i < entityListSize; i++) {
35-
TestdataValue value = valueList.get(i % valueListSize);
43+
TestdataValue value = initialized ? valueList.get(i % valueListSize) : null;
3644
TestdataEntity entity = new TestdataEntity("Generated Entity " + i, value);
3745
entityList.add(entity);
3846
}

0 commit comments

Comments
 (0)