Skip to content

Commit 706173f

Browse files
authored
feat: allow shadow variable updaters to read the solution (#2029)
1 parent b156174 commit 706173f

File tree

23 files changed

+582
-29
lines changed

23 files changed

+582
-29
lines changed

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/ChangedVariableNotifier.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ public CollectionInverseVariableSupply getCollectionInverseVariableSupply(Variab
3232
}
3333
}
3434

35+
public @Nullable Solution_ getWorkingSolution() {
36+
return innerScoreDirector != null ? innerScoreDirector.getWorkingSolution() : null;
37+
}
38+
3539
@SuppressWarnings("unchecked")
3640
public static <Solution_> ChangedVariableNotifier<Solution_> empty() {
3741
return (ChangedVariableNotifier<Solution_>) EMPTY;

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DeclarativeShadowVariableDescriptor.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,20 @@ public void processAnnotations(DescriptorPolicy descriptorPolicy) {
4444
throw new IllegalStateException("DeclarativeShadowVariableDescriptor was created when method is empty.");
4545
}
4646

47+
var solutionClass = entityDescriptor.getSolutionDescriptor().getSolutionClass();
4748
var method = ReflectionHelper.getDeclaredMethod(variableMemberAccessor.getDeclaringClass(), methodName);
49+
if (method == null) {
50+
// Retry with the solution class
51+
method = ReflectionHelper.getDeclaredMethod(variableMemberAccessor.getDeclaringClass(), methodName, solutionClass);
52+
}
4853

4954
if (method == null) {
5055
throw new IllegalArgumentException("""
51-
@%s (%s) defines a supplierMethod (%s) that does not exist inside its declaring class (%s).
52-
Maybe you misspelled the supplierMethod name?"""
56+
@%s (%s) defines a supplierName (%s) that does not exist inside its declaring class (%s).
57+
Maybe you included a parameter which is not a planning solution (%s)?
58+
Maybe you misspelled the supplierName name?"""
5359
.formatted(ShadowVariable.class.getSimpleName(), variableName, methodName,
54-
variableMemberAccessor.getDeclaringClass().getCanonicalName()));
60+
variableMemberAccessor.getDeclaringClass().getCanonicalName(), solutionClass.getName()));
5561
}
5662

5763
var shadowVariableUpdater = method.getAnnotation(ShadowSources.class);
@@ -65,7 +71,8 @@ public void processAnnotations(DescriptorPolicy descriptorPolicy) {
6571
}
6672
this.calculator =
6773
entityDescriptor.getSolutionDescriptor().getMemberAccessorFactory().buildAndCacheMemberAccessor(method,
68-
MemberAccessorFactory.MemberAccessorType.FIELD_OR_READ_METHOD, ShadowSources.class,
74+
MemberAccessorFactory.MemberAccessorType.FIELD_OR_READ_METHOD_WITH_OPTIONAL_PARAMETER,
75+
ShadowSources.class,
6976
descriptorPolicy.getDomainAccessType());
7077

7178
sourcePaths = shadowVariableUpdater.value();

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSessionFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ public List<VariableUpdaterInfo<Solution_>> getUpdatersForEntityVariable(Object
399399
consistencyTracker.getDeclarativeEntityConsistencyState(
400400
declarativeShadowVariableDescriptor.getEntityDescriptor()),
401401
declarativeShadowVariableDescriptor.getMemberAccessor(),
402-
declarativeShadowVariableDescriptor.getCalculator()::executeGetter);
402+
declarativeShadowVariableDescriptor.getCalculator());
403403
if (declarativeShadowVariableDescriptor.getAlignmentKeyMap() != null) {
404404
var alignmentKeyFunction = declarativeShadowVariableDescriptor.getAlignmentKeyMap();
405405
var alignmentKeyToAlignedEntitiesMap = new HashMap<Object, List<Object>>();
@@ -492,7 +492,7 @@ private static <Solution_> VariableReferenceGraph buildArbitrarySingleEntityGrap
492492
graphDescriptor.consistencyTracker().getDeclarativeEntityConsistencyState(
493493
declarativeShadowVariableDescriptor.getEntityDescriptor()),
494494
declarativeShadowVariableDescriptor.getMemberAccessor(),
495-
declarativeShadowVariableDescriptor.getCalculator()::executeGetter)));
495+
declarativeShadowVariableDescriptor.getCalculator())));
496496
}
497497

498498
private static <Solution_> Map<VariableMetaModel<?, ?, ?>, Set<VariableSourceReference>> createGraphNodes(

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/SingleDirectionalParentVariableReferenceGraph.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public SingleDirectionalParentVariableReferenceGraph(
5858
variableDescriptor,
5959
entityConsistencyState,
6060
variableDescriptor.getMemberAccessor(),
61-
variableDescriptor.getCalculator()::executeGetter);
61+
variableDescriptor.getCalculator());
6262
sortedVariableUpdaterInfos[updaterIndex++] = variableUpdaterInfo;
6363

6464
for (var source : variableDescriptor.getSources()) {

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableUpdaterInfo.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import java.util.Arrays;
44
import java.util.Objects;
5-
import java.util.function.Function;
5+
import java.util.function.BiFunction;
66

77
import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor;
88
import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel;
@@ -17,16 +17,20 @@ public record VariableUpdaterInfo<Solution_>(
1717
DeclarativeShadowVariableDescriptor<Solution_> variableDescriptor,
1818
EntityConsistencyState<Solution_, Object> entityConsistencyState,
1919
MemberAccessor memberAccessor,
20-
Function<Object, Object> calculator,
20+
BiFunction<@Nullable Solution_, Object, Object> calculator,
2121
@Nullable Object[] groupEntities) {
2222

2323
public VariableUpdaterInfo(VariableMetaModel<Solution_, ?, ?> id,
2424
int groupId,
2525
DeclarativeShadowVariableDescriptor<Solution_> variableDescriptor,
2626
EntityConsistencyState<Solution_, Object> entityConsistencyState,
2727
MemberAccessor memberAccessor,
28-
Function<Object, Object> calculator) {
29-
this(id, groupId, variableDescriptor, entityConsistencyState, memberAccessor, calculator, null);
28+
MemberAccessor calculatorAccessor) {
29+
this(id, groupId, variableDescriptor, entityConsistencyState, memberAccessor,
30+
calculatorAccessor.getGetterMethodParameterType() != null
31+
? (solution, entity) -> calculatorAccessor.executeGetter(entity, solution)
32+
: (_ignore, entity) -> calculatorAccessor.executeGetter(entity),
33+
null);
3034
}
3135

3236
public VariableUpdaterInfo<Solution_> withGroupId(int groupId) {
@@ -40,7 +44,8 @@ public VariableUpdaterInfo<Solution_> withGroupEntities(Object[] groupEntities)
4044
}
4145

4246
public boolean updateIfChanged(Object entity, ChangedVariableNotifier<Solution_> changedVariableNotifier) {
43-
return updateIfChanged(entity, calculator.apply(entity), changedVariableNotifier);
47+
return updateIfChanged(entity, calculator.apply(changedVariableNotifier.getWorkingSolution(), entity),
48+
changedVariableNotifier);
4449
}
4550

4651
public boolean updateIfChanged(Object entity, @Nullable Object newValue,

core/src/test/java/ai/timefold/solver/core/impl/domain/solution/descriptor/SolutionDescriptorTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,12 +420,13 @@ void testBadChainedAndListModel() {
420420
}
421421

422422
@Test
423-
void missingDeclarativeSupplierMethod() {
423+
void missingDeclarativeSupplierName() {
424424
assertThatCode(TestdataDeclarativeMissingSupplierSolution::buildSolutionDescriptor)
425425
.hasMessageContainingAll("@ShadowVariable (endTime)",
426-
"supplierMethod (calculateEndTime) that does not exist",
426+
"supplierName (calculateEndTime) that does not exist",
427427
"inside its declaring class (ai.timefold.solver.core.testdomain.shadow.missing.TestdataDeclarativeMissingSupplierValue).",
428-
"Maybe you misspelled the supplierMethod name?");
428+
"Maybe you included a parameter which is not a planning solution (ai.timefold.solver.core.testdomain.shadow.missing.TestdataDeclarativeMissingSupplierSolution)",
429+
"Maybe you misspelled the supplierName name?");
429430
}
430431

431432
@Test

core/src/test/java/ai/timefold/solver/core/impl/domain/variable/ShadowVariableUpdateTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import ai.timefold.solver.core.testdomain.shadow.full.TestdataShadowedFullEntity;
1919
import ai.timefold.solver.core.testdomain.shadow.full.TestdataShadowedFullSolution;
2020
import ai.timefold.solver.core.testdomain.shadow.full.TestdataShadowedFullValue;
21+
import ai.timefold.solver.core.testdomain.shadow.parameter.TestdataBasicVarParameterEntity;
22+
import ai.timefold.solver.core.testdomain.shadow.parameter.TestdataBasicVarParameterSolution;
23+
import ai.timefold.solver.core.testdomain.shadow.parameter.TestdataBasicVarParameterValue;
2124

2225
import org.assertj.core.api.Assertions;
2326
import org.junit.jupiter.api.Test;
@@ -119,6 +122,28 @@ void solutionUpdateBasicShadowVariables() {
119122
assertThat(value2.getEndTime()).isEqualTo(value2.getStartTime().plus(value2.getDuration()));
120123
}
121124

125+
@Test
126+
void solutionUpdateBasicShadowVariablesWithParameter() {
127+
var value1 = new TestdataBasicVarParameterValue("v1", Duration.ofSeconds(10));
128+
var value2 = new TestdataBasicVarParameterValue("v2", Duration.ofSeconds(20));
129+
var entity1 = new TestdataBasicVarParameterEntity("e1", value1);
130+
var entity2 = new TestdataBasicVarParameterEntity("e2", value2);
131+
var entity3 = new TestdataBasicVarParameterEntity("e3", value1);
132+
var solution = new TestdataBasicVarParameterSolution();
133+
solution.setEntities(List.of(entity1, entity2, entity3));
134+
solution.setValues(List.of(value1, value2));
135+
solution.setProblemFacts(List.of(1, "Data", new Object()));
136+
SolutionManager.updateShadowVariables(solution);
137+
assertThat(value1.getEntityList()).containsExactly(entity1, entity3);
138+
assertThat(value2.getEntityList()).containsExactly(entity2);
139+
assertThat(value1.getStartTime())
140+
.isEqualTo(TestdataBasicVarParameterValue.DEFAULT_TIME.plusDays(value1.getEntityList().size()));
141+
assertThat(value1.getEndTime()).isEqualTo(value1.getStartTime().plus(value1.getDuration()));
142+
assertThat(value2.getStartTime())
143+
.isEqualTo(TestdataBasicVarParameterValue.DEFAULT_TIME.plusDays(value2.getEntityList().size()));
144+
assertThat(value2.getEndTime()).isEqualTo(value2.getStartTime().plus(value2.getDuration()));
145+
}
146+
122147
@Test
123148
void updateChainedShadowVariables() {
124149
var value1 = new TestdataChainedVarValue("v1", Duration.ofDays(10));

core/src/test/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSourceTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
import ai.timefold.solver.core.testdomain.shadow.invalid.TestdataInvalidDeclarativeEntity;
2828
import ai.timefold.solver.core.testdomain.shadow.invalid.TestdataInvalidDeclarativeSolution;
2929
import ai.timefold.solver.core.testdomain.shadow.invalid.TestdataInvalidDeclarativeValue;
30+
import ai.timefold.solver.core.testdomain.shadow.invalid.parameter.TestdataInvalidDeclarativeParameterEntity;
31+
import ai.timefold.solver.core.testdomain.shadow.invalid.parameter.TestdataInvalidDeclarativeParameterSolution;
32+
import ai.timefold.solver.core.testdomain.shadow.invalid.parameter.TestdataInvalidDeclarativeParameterValue;
3033

3134
import org.junit.jupiter.api.Test;
3235

@@ -624,6 +627,16 @@ void invalidPathMultipleFactsInARow() {
624627
" in a row.");
625628
}
626629

630+
@Test
631+
void invalidParameter() {
632+
assertThatCode(() -> SolutionDescriptor.buildSolutionDescriptor(
633+
TestdataInvalidDeclarativeParameterSolution.class, TestdataInvalidDeclarativeParameterEntity.class,
634+
TestdataInvalidDeclarativeParameterValue.class)
635+
.getMetaModel())
636+
.hasMessageContaining(
637+
"Maybe you included a parameter which is not a planning solution (ai.timefold.solver.core.testdomain.shadow.invalid.parameter.TestdataInvalidDeclarativeParameterSolution)?");
638+
}
639+
627640
@Test
628641
void preferGetterWhenFieldTheSameType() {
629642
record TestClass(String name) {

core/src/test/java/ai/timefold/solver/core/testdomain/shadow/counting/TestdataCountingValue.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void setCount(Integer count) {
5252
}
5353

5454
@ShadowSources({ "previous.count", "entity" })
55-
public Integer countSupplier() {
55+
public Integer countSupplier(TestdataCountingSolution solution) {
5656
if (calledCount != 0) {
5757
throw new IllegalStateException("Supplier for entity %s was already called."
5858
.formatted(entity));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package ai.timefold.solver.core.testdomain.shadow.invalid.parameter;
2+
3+
import java.util.List;
4+
5+
import ai.timefold.solver.core.api.domain.entity.PlanningEntity;
6+
import ai.timefold.solver.core.api.domain.variable.PlanningListVariable;
7+
import ai.timefold.solver.core.testdomain.TestdataObject;
8+
9+
@PlanningEntity
10+
public class TestdataInvalidDeclarativeParameterEntity extends TestdataObject {
11+
@PlanningListVariable
12+
List<TestdataInvalidDeclarativeParameterValue> values;
13+
14+
public TestdataInvalidDeclarativeParameterEntity() {
15+
}
16+
17+
public TestdataInvalidDeclarativeParameterEntity(String code) {
18+
super(code);
19+
}
20+
21+
}

0 commit comments

Comments
 (0)