Skip to content

Commit 5936b3d

Browse files
feat: add forEachUnfilteredUniquePair to PrecomputeFactory (#1898)
1 parent 21d8b8b commit 5936b3d

File tree

5 files changed

+302
-4
lines changed

5 files changed

+302
-4
lines changed

core/src/main/java/ai/timefold/solver/core/api/score/stream/PrecomputeFactory.java

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import java.util.function.Function;
44

55
import ai.timefold.solver.core.api.domain.entity.PlanningEntity;
6+
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream;
7+
import ai.timefold.solver.core.api.score.stream.bi.BiJoiner;
68
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
79

810
/**
@@ -41,4 +43,123 @@ public interface PrecomputeFactory {
4143
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
4244
*/
4345
<A> UniConstraintStream<A> forEachUnfiltered(Class<A> sourceClass);
46+
47+
/**
48+
* As defined by {@link ConstraintFactory#forEachUniquePair(Class)},
49+
* with the additional change that the problem facts/entities are unfiltered.
50+
* <p>
51+
* For example,
52+
* <p>
53+
*
54+
* <pre>
55+
* precomputeFactory.forEachUnfilteredUniquePair(Shift.class);
56+
* </pre>
57+
* <p>
58+
* Would roughly be equivalent to
59+
* <p>
60+
*
61+
* <pre>
62+
* constraintFactory.forEachUnfiltered(Shift.class)
63+
* .join(constraintFactory.forEachUnfiltered(Shift.class),
64+
* Joiners.lessThan(Shift::getId));
65+
* </pre>
66+
* <p>
67+
* Important: no variables can be referenced in any operations performed
68+
* by the returned {@link ConstraintStream}, otherwise a score corruption will
69+
* occur.
70+
* See the note in {@link ConstraintFactory#precompute(Function)} for
71+
* more details.
72+
*
73+
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
74+
*/
75+
@SuppressWarnings("unchecked")
76+
default <A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass) {
77+
return forEachUnfilteredUniquePair(sourceClass, new BiJoiner[] {});
78+
}
79+
80+
/**
81+
* As defined by {@link ConstraintFactory#forEachUniquePair(Class, BiJoiner)},
82+
* with the additional change that the problem facts/entities are unfiltered.
83+
* <p>
84+
* For example,
85+
* <p>
86+
*
87+
* <pre>
88+
* precomputeFactory.forEachUnfilteredUniquePair(Shift.class, Joiners.equal(Shift::getLocation));
89+
* </pre>
90+
* <p>
91+
* Would roughly be equivalent to
92+
* <p>
93+
*
94+
* <pre>
95+
* constraintFactory.forEachUnfiltered(Shift.class)
96+
* .join(constraintFactory.forEachUnfiltered(Shift.class),
97+
* Joiners.lessThan(Shift::getId),
98+
* Joiners.equal(Shift::getLocation));
99+
* </pre>
100+
* <p>
101+
* Important: no variables can be referenced in any operations performed
102+
* by the returned {@link ConstraintStream}, otherwise a score corruption will
103+
* occur.
104+
* See the note in {@link ConstraintFactory#precompute(Function)} for
105+
* more details.
106+
*
107+
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
108+
*/
109+
@SuppressWarnings("unchecked")
110+
default <A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass, BiJoiner<A, A> joiner) {
111+
return forEachUnfilteredUniquePair(sourceClass, new BiJoiner[] { joiner });
112+
}
113+
114+
/**
115+
* As defined by {@link #forEachUnfilteredUniquePair(Class, BiJoiner)}.
116+
*
117+
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
118+
* @return a stream that matches every unique combination of A and another A for which all the
119+
* {@link BiJoiner joiners} are true
120+
*/
121+
@SuppressWarnings("unchecked")
122+
default <A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass, BiJoiner<A, A> joiner1,
123+
BiJoiner<A, A> joiner2) {
124+
return forEachUnfilteredUniquePair(sourceClass, new BiJoiner[] { joiner1, joiner2 });
125+
}
126+
127+
/**
128+
* As defined by {@link #forEachUnfilteredUniquePair(Class, BiJoiner)}.
129+
*
130+
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
131+
* @return a stream that matches every unique combination of A and another A for which all the
132+
* {@link BiJoiner joiners} are true
133+
*/
134+
@SuppressWarnings("unchecked")
135+
default <A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass, BiJoiner<A, A> joiner1,
136+
BiJoiner<A, A> joiner2, BiJoiner<A, A> joiner3) {
137+
return forEachUnfilteredUniquePair(sourceClass, new BiJoiner[] { joiner1, joiner2, joiner3 });
138+
}
139+
140+
/**
141+
* As defined by {@link #forEachUnfilteredUniquePair(Class, BiJoiner)}.
142+
*
143+
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
144+
* @return a stream that matches every unique combination of A and another A for which all the
145+
* {@link BiJoiner joiners} are true
146+
*/
147+
@SuppressWarnings("unchecked")
148+
default <A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass, BiJoiner<A, A> joiner1,
149+
BiJoiner<A, A> joiner2, BiJoiner<A, A> joiner3, BiJoiner<A, A> joiner4) {
150+
return forEachUnfilteredUniquePair(sourceClass, new BiJoiner[] { joiner1, joiner2, joiner3, joiner4 });
151+
}
152+
153+
/**
154+
* As defined by {@link #forEachUnfilteredUniquePair(Class, BiJoiner)}.
155+
* <p>
156+
* This method causes <i>Unchecked generics array creation for varargs parameter</i> warnings,
157+
* but we can't fix it with a {@link SafeVarargs} annotation because it's an interface method.
158+
* Therefore, there are overloaded methods with up to 4 {@link BiJoiner} parameters.
159+
*
160+
* @param <A> the type of the matched problem fact or {@link PlanningEntity planning entity}
161+
* @return a stream that matches every unique combination of A and another A for which all the
162+
* {@link BiJoiner joiners} are true
163+
*/
164+
<A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass, BiJoiner<A, A>... joiners);
44165
}

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ public <A> UniConstraintStream<A> forEachUnfiltered(Class<A> sourceClass) {
178178
return forEachForCriteria(sourceClass, ForEachFilteringCriteria.ALL);
179179
}
180180

181-
<A> UniConstraintStream<A> forEachUnfilteredStatic(Class<A> sourceClass) {
181+
<A> UniConstraintStream<A> forEachUnfilteredPrecomputed(Class<A> sourceClass) {
182182
return forEachForCriteria(sourceClass, ForEachFilteringCriteria.ALL, RetrievalSemantics.PRECOMPUTE);
183183
}
184184

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
package ai.timefold.solver.core.impl.score.stream.bavet;
22

33
import ai.timefold.solver.core.api.score.stream.PrecomputeFactory;
4+
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream;
5+
import ai.timefold.solver.core.api.score.stream.bi.BiJoiner;
46
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
57

68
public record BavetStaticDataFactory<Solution_>(
79
BavetConstraintFactory<Solution_> constraintFactory) implements PrecomputeFactory {
810
@Override
911
public <A> UniConstraintStream<A> forEachUnfiltered(Class<A> sourceClass) {
10-
return constraintFactory.forEachUnfilteredStatic(sourceClass);
12+
return constraintFactory.forEachUnfilteredPrecomputed(sourceClass);
13+
}
14+
15+
@Override
16+
public <A> BiConstraintStream<A, A> forEachUnfilteredUniquePair(Class<A> sourceClass, BiJoiner<A, A>... joiners) {
17+
return constraintFactory.forEachUniquePair(this::forEachUnfiltered, sourceClass, joiners);
1118
}
1219
}

core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/InnerConstraintFactory.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import ai.timefold.solver.core.api.score.stream.ConstraintProvider;
1919
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream;
2020
import ai.timefold.solver.core.api.score.stream.bi.BiJoiner;
21+
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
2122
import ai.timefold.solver.core.impl.bavet.bi.joiner.BiJoinerComber;
2223
import ai.timefold.solver.core.impl.bavet.bi.joiner.DefaultBiJoiner;
2324
import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor;
@@ -31,10 +32,16 @@ public abstract class InnerConstraintFactory<Solution_, Constraint_ extends Cons
3132
@Override
3233
public <A> @NonNull BiConstraintStream<A, A> forEachUniquePair(@NonNull Class<A> sourceClass,
3334
BiJoiner<A, A> @NonNull... joiners) {
35+
return forEachUniquePair(this::forEach, sourceClass, joiners);
36+
}
37+
38+
public <A> @NonNull BiConstraintStream<A, A> forEachUniquePair(Function<Class<A>, UniConstraintStream<A>> streamFunction,
39+
@NonNull Class<A> sourceClass,
40+
BiJoiner<A, A> @NonNull... joiners) {
3441
BiJoinerComber<A, A> joinerComber = BiJoinerComber.comb(joiners);
3542
joinerComber.addJoiner(buildLessThanId(sourceClass));
36-
return ((InnerUniConstraintStream<A>) forEach(sourceClass))
37-
.join(forEach(sourceClass), joinerComber);
43+
return ((InnerUniConstraintStream<A>) streamFunction.apply(sourceClass))
44+
.join(streamFunction.apply(sourceClass), joinerComber);
3845
}
3946

4047
private <A> DefaultBiJoiner<A, A> buildLessThanId(Class<A> sourceClass) {

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamPrecomputeTest.java

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,169 @@ public void filter_1_changed() {
191191
assertMatch(value2, entity3));
192192
}
193193

194+
@TestTemplate
195+
public void filter_0_changed_forEachUnfilteredUniquePair() {
196+
var solution = TestdataLavishSolution.generateSolution();
197+
var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup");
198+
var valueGroup = new TestdataLavishValueGroup("MyValueGroup");
199+
solution.getEntityGroupList().add(entityGroup);
200+
solution.getValueGroupList().add(valueGroup);
201+
202+
var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup));
203+
solution.getValueList().add(value1);
204+
var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup));
205+
solution.getValueList().add(value2);
206+
var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null));
207+
solution.getValueList().add(value3);
208+
209+
var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1));
210+
solution.getEntityList().add(entity1);
211+
var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1);
212+
solution.getEntityList().add(entity2);
213+
var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(),
214+
value1);
215+
solution.getEntityList().add(entity3);
216+
217+
var scoreDirector =
218+
buildScoreDirector(factory -> factory
219+
.precompute(data -> data.forEachUnfilteredUniquePair(TestdataLavishEntity.class,
220+
Joiners.equal(TestdataLavishEntity::getEntityGroup)))
221+
.filter((a, b) -> a.getValue() == value1)
222+
.penalize(SimpleScore.ONE)
223+
.asConstraint(TEST_CONSTRAINT_NAME));
224+
225+
// From scratch
226+
Mockito.reset(entity1);
227+
scoreDirector.setWorkingSolution(solution);
228+
assertScore(scoreDirector,
229+
assertMatch(entity1, entity2));
230+
Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup();
231+
232+
// Incrementally update a variable
233+
Mockito.reset(entity1);
234+
scoreDirector.beforeVariableChanged(entity1, "value");
235+
entity1.setValue(solution.getFirstValue());
236+
scoreDirector.afterVariableChanged(entity1, "value");
237+
assertScore(scoreDirector);
238+
Mockito.verify(entity1, Mockito.never()).getEntityGroup();
239+
240+
// Incrementally update a variable
241+
Mockito.reset(entity1);
242+
scoreDirector.beforeVariableChanged(entity1, "value");
243+
entity1.setValue(value1);
244+
scoreDirector.afterVariableChanged(entity1, "value");
245+
assertScore(scoreDirector,
246+
assertMatch(entity1, entity2));
247+
Mockito.verify(entity1, Mockito.never()).getEntityGroup();
248+
249+
// Incrementally update a fact
250+
scoreDirector.beforeProblemPropertyChanged(entity3);
251+
entity3.setEntityGroup(entityGroup);
252+
scoreDirector.afterProblemPropertyChanged(entity3);
253+
assertScore(scoreDirector,
254+
assertMatch(entity1, entity2),
255+
assertMatch(entity1, entity3),
256+
assertMatch(entity2, entity3));
257+
258+
// Remove entity
259+
scoreDirector.beforeEntityRemoved(entity3);
260+
solution.getEntityList().remove(entity3);
261+
scoreDirector.afterEntityRemoved(entity3);
262+
assertScore(scoreDirector,
263+
assertMatch(entity1, entity2));
264+
265+
// Add it back again, to make sure it was properly removed before
266+
scoreDirector.beforeEntityAdded(entity3);
267+
solution.getEntityList().add(entity3);
268+
scoreDirector.afterEntityAdded(entity3);
269+
assertScore(scoreDirector,
270+
assertMatch(entity1, entity2),
271+
assertMatch(entity1, entity3),
272+
assertMatch(entity2, entity3));
273+
}
274+
275+
@TestTemplate
276+
public void filter_1_changed_forEachUnfilteredUniquePair() {
277+
var solution = TestdataLavishSolution.generateSolution();
278+
var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup");
279+
var valueGroup = new TestdataLavishValueGroup("MyValueGroup");
280+
solution.getEntityGroupList().add(entityGroup);
281+
solution.getValueGroupList().add(valueGroup);
282+
283+
var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup));
284+
solution.getValueList().add(value1);
285+
var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup));
286+
solution.getValueList().add(value2);
287+
var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null));
288+
solution.getValueList().add(value3);
289+
290+
var entity1 = new TestdataLavishEntity("MyEntity 1", entityGroup, value1);
291+
solution.getEntityList().add(entity1);
292+
var entity2 = Mockito.spy(new TestdataLavishEntity("MyEntity 2", entityGroup, value1));
293+
solution.getEntityList().add(entity2);
294+
var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(),
295+
value2);
296+
solution.getEntityList().add(entity3);
297+
298+
var scoreDirector =
299+
buildScoreDirector(factory -> factory
300+
.precompute(data -> data.forEachUnfilteredUniquePair(TestdataLavishEntity.class,
301+
Joiners.equal(TestdataLavishEntity::getEntityGroup)))
302+
.filter((a, b) -> b.getValue() == value1)
303+
.penalize(SimpleScore.ONE)
304+
.asConstraint(TEST_CONSTRAINT_NAME));
305+
306+
// From scratch
307+
Mockito.reset(entity2);
308+
scoreDirector.setWorkingSolution(solution);
309+
assertScore(scoreDirector,
310+
assertMatch(entity1, entity2));
311+
Mockito.verify(entity2, Mockito.atLeastOnce()).getEntityGroup();
312+
313+
// Incrementally update a variable
314+
Mockito.reset(entity2);
315+
scoreDirector.beforeVariableChanged(entity2, "value");
316+
entity2.setValue(solution.getFirstValue());
317+
scoreDirector.afterVariableChanged(entity2, "value");
318+
assertScore(scoreDirector);
319+
Mockito.verify(entity2, Mockito.never()).getEntityGroup();
320+
321+
// Incrementally update a variable
322+
Mockito.reset(entity2);
323+
scoreDirector.beforeVariableChanged(entity2, "value");
324+
entity2.setValue(value1);
325+
scoreDirector.afterVariableChanged(entity2, "value");
326+
assertScore(scoreDirector,
327+
assertMatch(entity1, entity2));
328+
Mockito.verify(entity2, Mockito.never()).getEntityGroup();
329+
330+
// Incrementally update a fact
331+
scoreDirector.beforeProblemPropertyChanged(entity3);
332+
entity3.setValue(value1);
333+
entity3.setEntityGroup(entityGroup);
334+
scoreDirector.afterProblemPropertyChanged(entity3);
335+
assertScore(scoreDirector,
336+
assertMatch(entity1, entity2),
337+
assertMatch(entity1, entity3),
338+
assertMatch(entity2, entity3));
339+
340+
// Remove entity
341+
scoreDirector.beforeEntityRemoved(entity3);
342+
solution.getEntityList().remove(entity3);
343+
scoreDirector.afterEntityRemoved(entity3);
344+
assertScore(scoreDirector,
345+
assertMatch(entity1, entity2));
346+
347+
// Add it back again, to make sure it was properly removed before
348+
scoreDirector.beforeEntityAdded(entity3);
349+
solution.getEntityList().add(entity3);
350+
scoreDirector.afterEntityAdded(entity3);
351+
assertScore(scoreDirector,
352+
assertMatch(entity1, entity2),
353+
assertMatch(entity1, entity3),
354+
assertMatch(entity2, entity3));
355+
}
356+
194357
private <A, B> void assertPrecompute(TestdataLavishSolution solution,
195358
List<Pair<A, B>> expectedValues,
196359
Function<PrecomputeFactory, BiConstraintStream<A, B>> entityStreamSupplier) {

0 commit comments

Comments
 (0)