Skip to content

Commit 1a1909c

Browse files
perf: group declarative variables after aligned variables when possible (#1725)
1 parent ee2410f commit 1a1909c

File tree

4 files changed

+113
-60
lines changed

4 files changed

+113
-60
lines changed

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSessionFactory.java

Lines changed: 73 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@ private record GroupVariableUpdaterInfo<Solution_>(
189189
List<DeclarativeShadowVariableDescriptor<Solution_>> sortedDeclarativeVariableDescriptors,
190190
List<VariableUpdaterInfo<Solution_>> allUpdaters,
191191
List<VariableUpdaterInfo<Solution_>> groupedUpdaters,
192-
Map<DeclarativeShadowVariableDescriptor<Solution_>, Map<Object, VariableUpdaterInfo<Solution_>>> variableToEntityToGroupUpdater) {
192+
Map<DeclarativeShadowVariableDescriptor<Solution_>, Map<Object, List<VariableUpdaterInfo<Solution_>>>> variableToEntityToGroupUpdater) {
193193

194194
public List<VariableUpdaterInfo<Solution_>> getUpdatersForEntityVariable(Object entity,
195195
DeclarativeShadowVariableDescriptor<Solution_> declarativeShadowVariableDescriptor) {
196196
if (variableToEntityToGroupUpdater.containsKey(declarativeShadowVariableDescriptor)) {
197-
var updater = variableToEntityToGroupUpdater.get(declarativeShadowVariableDescriptor).get(entity);
198-
if (updater != null) {
199-
return List.of(updater);
197+
var updaters = variableToEntityToGroupUpdater.get(declarativeShadowVariableDescriptor).get(entity);
198+
if (updaters != null) {
199+
return updaters;
200200
}
201201
}
202202
for (var shadowVariableDescriptor : sortedDeclarativeVariableDescriptors) {
@@ -225,24 +225,7 @@ public List<VariableUpdaterInfo<Solution_>> getUpdatersForEntityVariable(Object
225225
var groupVariables = new ArrayList<DeclarativeShadowVariableDescriptor<Solution_>>();
226226
groupIndexToVariables.put(0, groupVariables);
227227
for (var declarativeShadowVariableDescriptor : sortedDeclarativeVariableDescriptors) {
228-
// If a @ShadowSources has a group source (i.e. "visitGroup[].arrivalTimes"),
229-
// create a new group since it must wait until all members of that group are processed
230-
var hasGroupSources = Arrays.stream(declarativeShadowVariableDescriptor.getSources())
231-
.anyMatch(rootVariableSource -> rootVariableSource.parentVariableType() == ParentVariableType.GROUP);
232-
233-
// If a @ShadowSources has an alignment key,
234-
// create a new group since multiple entities must be updated for this node
235-
var hasAlignmentKey = declarativeShadowVariableDescriptor.getAlignmentKeyMap() != null;
236-
237-
// If the previous @ShadowSources has an alignment key,
238-
// create a new group since we are updating a single entity again
239-
// NOTE: Can potentially be optimized/share a node if VariableUpdaterInfo
240-
// update each group member independently after the alignmentKey
241-
var previousHasAlignmentKey = !groupVariables.isEmpty() && groupVariables.get(0).getAlignmentKeyMap() != null;
242-
243-
if (!groupVariables.isEmpty() && (hasGroupSources
244-
|| hasAlignmentKey
245-
|| previousHasAlignmentKey)) {
228+
if (shouldCreateNewGroupForVariable(declarativeShadowVariableDescriptor, groupVariables)) {
246229
groupVariables = new ArrayList<>();
247230
groupIndexToVariables.put(groupIndexToVariables.size(), groupVariables);
248231
}
@@ -252,11 +235,12 @@ public List<VariableUpdaterInfo<Solution_>> getUpdatersForEntityVariable(Object
252235
var out = new HashMap<VariableMetaModel<Solution_, ?, ?>, GroupVariableUpdaterInfo<Solution_>>();
253236
var allUpdaters = new ArrayList<VariableUpdaterInfo<Solution_>>();
254237
var groupedUpdaters =
255-
new HashMap<DeclarativeShadowVariableDescriptor<Solution_>, Map<Object, VariableUpdaterInfo<Solution_>>>();
238+
new HashMap<DeclarativeShadowVariableDescriptor<Solution_>, Map<Object, List<VariableUpdaterInfo<Solution_>>>>();
256239
var updaterKey = 0;
257240
for (var entryKey = 0; entryKey < groupIndexToVariables.size(); entryKey++) {
258241
var entryGroupVariables = groupIndexToVariables.get(entryKey);
259242
var updaters = new ArrayList<VariableUpdaterInfo<Solution_>>();
243+
var alignmentKeyToGroupIndex = new HashMap<Object, Integer>();
260244
for (var declarativeShadowVariableDescriptor : entryGroupVariables) {
261245
var updater = new VariableUpdaterInfo<>(
262246
declarativeShadowVariableDescriptor.getVariableMetaModel(),
@@ -265,31 +249,11 @@ public List<VariableUpdaterInfo<Solution_>> getUpdatersForEntityVariable(Object
265249
declarativeShadowVariableDescriptor.getEntityDescriptor().getShadowVariableLoopedDescriptor(),
266250
declarativeShadowVariableDescriptor.getMemberAccessor(),
267251
declarativeShadowVariableDescriptor.getCalculator()::executeGetter);
268-
if (declarativeShadowVariableDescriptor.getAlignmentKeyMap() != null) {
269-
var alignmentKeyFunction = declarativeShadowVariableDescriptor.getAlignmentKeyMap();
270-
var alignmentKeyToAlignedEntitiesMap = new HashMap<Object, List<Object>>();
271-
for (var entity : entities) {
272-
if (declarativeShadowVariableDescriptor.getEntityDescriptor().getEntityClass().isInstance(entity)) {
273-
var alignmentKey = alignmentKeyFunction.apply(entity);
274-
alignmentKeyToAlignedEntitiesMap.computeIfAbsent(alignmentKey, k -> new ArrayList<>()).add(entity);
275-
}
276-
}
277-
for (var alignmentGroup : alignmentKeyToAlignedEntitiesMap.entrySet()) {
278-
var updaterCopy = updater.withGroupId(updaterKey);
279-
if (alignmentGroup.getKey() == null) {
280-
updaters.add(updaterCopy);
281-
allUpdaters.add(updaterCopy);
282-
} else {
283-
updaterCopy = updaterCopy.withGroupEntities(alignmentGroup.getValue().toArray(new Object[0]));
284-
var variableUpdaterMap = groupedUpdaters.computeIfAbsent(declarativeShadowVariableDescriptor,
285-
ignored -> new IdentityHashMap<>());
286-
for (var entity : alignmentGroup.getValue()) {
287-
variableUpdaterMap.put(entity, updaterCopy);
288-
}
289-
}
290-
updaterKey++;
291-
}
292-
updaterKey--; // it will be incremented again at end of the loop
252+
if (entryGroupVariables.get(0).getAlignmentKeyMap() != null) {
253+
updaterKey = processAlignmentGroupVariableAndGetNextUpdaterKey(entities,
254+
declarativeShadowVariableDescriptor, entryGroupVariables, updater,
255+
updaterKey, updaters,
256+
allUpdaters, alignmentKeyToGroupIndex, groupedUpdaters);
293257
} else {
294258
updaters.add(updater);
295259
allUpdaters.add(updater);
@@ -301,12 +265,72 @@ public List<VariableUpdaterInfo<Solution_>> getUpdatersForEntityVariable(Object
301265
for (var declarativeShadowVariableDescriptor : entryGroupVariables) {
302266
out.put(declarativeShadowVariableDescriptor.getVariableMetaModel(), groupVariableUpdaterInfo);
303267
}
304-
updaterKey++;
268+
if (entryGroupVariables.get(0).getAlignmentKeyMap() == null) {
269+
updaterKey++;
270+
}
305271
}
306272
allUpdaters.replaceAll(updater -> updater.withGroupId(groupIndexToVariables.size()));
307273
return out;
308274
}
309275

276+
private static <Solution_> boolean shouldCreateNewGroupForVariable(
277+
DeclarativeShadowVariableDescriptor<Solution_> declarativeShadowVariableDescriptor,
278+
List<DeclarativeShadowVariableDescriptor<Solution_>> groupVariables) {
279+
// If a @ShadowSources has a group source (i.e. "visitGroup[].arrivalTimes"),
280+
// create a new group since it must wait until all members of that group are processed
281+
var hasGroupSources = Arrays.stream(declarativeShadowVariableDescriptor.getSources())
282+
.anyMatch(rootVariableSource -> rootVariableSource.parentVariableType() == ParentVariableType.GROUP);
283+
284+
// If a @ShadowSources has an alignment key,
285+
// create a new group since multiple entities must be updated for this node
286+
var alignmentKey = declarativeShadowVariableDescriptor.getAlignmentKeyName();
287+
var previousAlignmentKey = groupVariables.isEmpty() ? null : groupVariables.get(0).getAlignmentKeyName();
288+
289+
return !groupVariables.isEmpty() && (hasGroupSources
290+
|| (alignmentKey != null && !Objects.equals(alignmentKey, previousAlignmentKey)));
291+
}
292+
293+
private static <Solution_> int processAlignmentGroupVariableAndGetNextUpdaterKey(Object[] entities,
294+
DeclarativeShadowVariableDescriptor<Solution_> declarativeShadowVariableDescriptor,
295+
List<DeclarativeShadowVariableDescriptor<Solution_>> entryGroupVariables, VariableUpdaterInfo<Solution_> updater,
296+
int updaterKey, List<VariableUpdaterInfo<Solution_>> updaters,
297+
List<VariableUpdaterInfo<Solution_>> allUpdaters, HashMap<Object, Integer> alignmentKeyToGroupIndex,
298+
Map<DeclarativeShadowVariableDescriptor<Solution_>, Map<Object, List<VariableUpdaterInfo<Solution_>>>> groupedUpdaters) {
299+
var alignmentKeyFunction = entryGroupVariables.get(0).getAlignmentKeyMap();
300+
var alignmentKeyToAlignedEntitiesMap = new HashMap<Object, List<Object>>();
301+
for (var entity : entities) {
302+
if (declarativeShadowVariableDescriptor.getEntityDescriptor().getEntityClass().isInstance(entity)) {
303+
var alignmentKey = alignmentKeyFunction.apply(entity);
304+
alignmentKeyToAlignedEntitiesMap.computeIfAbsent(alignmentKey, k -> new ArrayList<>()).add(entity);
305+
}
306+
}
307+
for (var alignmentGroup : alignmentKeyToAlignedEntitiesMap.entrySet()) {
308+
if (alignmentGroup.getKey() == null) {
309+
var updaterCopy = updater.withGroupId(updaterKey);
310+
updaters.add(updaterCopy);
311+
allUpdaters.add(updaterCopy);
312+
updaterKey++;
313+
} else {
314+
final var newAlignmentUpdaterKey = updaterKey;
315+
final var alignmentUpdaterKey = (int) alignmentKeyToGroupIndex.computeIfAbsent(alignmentGroup.getKey(),
316+
ignored -> newAlignmentUpdaterKey);
317+
if (alignmentUpdaterKey == newAlignmentUpdaterKey) {
318+
updaterKey++;
319+
}
320+
var updaterCopy = updater.withGroupId(alignmentUpdaterKey);
321+
updaterCopy = updaterCopy.withGroupEntities(alignmentGroup.getValue().toArray(new Object[0]),
322+
declarativeShadowVariableDescriptor.getAlignmentKeyName() != null);
323+
var variableUpdaterMap = groupedUpdaters.computeIfAbsent(declarativeShadowVariableDescriptor,
324+
ignored -> groupedUpdaters.getOrDefault(entryGroupVariables.get(0),
325+
new IdentityHashMap<>()));
326+
for (var entity : alignmentGroup.getValue()) {
327+
variableUpdaterMap.computeIfAbsent(entity, ignored -> new ArrayList<>()).add(updaterCopy);
328+
}
329+
}
330+
}
331+
return updaterKey;
332+
}
333+
310334
private static <Solution_> VariableReferenceGraph buildArbitrarySingleEntityGraph(
311335
SolutionDescriptor<Solution_> solutionDescriptor,
312336
VariableReferenceGraphBuilder<Solution_> variableReferenceGraphBuilder, Object[] entities,

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultVariableReferenceGraph.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@ public DefaultVariableReferenceGraph(VariableReferenceGraphBuilder<Solution_> ou
1919

2020
var entityToVariableReferenceMap = new IdentityHashMap<Object, List<GraphNode<Solution_>>>();
2121
for (var instance : nodeList) {
22-
var entity = instance.entity();
23-
entityToVariableReferenceMap.computeIfAbsent(entity, ignored -> new ArrayList<>())
24-
.add(instance);
22+
if (instance.groupEntityIds() == null) {
23+
var entity = instance.entity();
24+
entityToVariableReferenceMap.computeIfAbsent(entity, ignored -> new ArrayList<>())
25+
.add(instance);
26+
} else {
27+
for (var groupEntity : instance.variableReferences().get(0).groupEntities()) {
28+
entityToVariableReferenceMap.computeIfAbsent(groupEntity, ignored -> new ArrayList<>())
29+
.add(instance);
30+
}
31+
}
2532
}
2633
// Immutable optimized version of the map, now that it won't be updated anymore.
2734
var immutableEntityToVariableReferenceMap = mapOfListsDeepCopyOf(entityToVariableReferenceMap);

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableUpdaterInfo.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,48 @@ public record VariableUpdaterInfo<Solution_>(
1818
@Nullable ShadowVariableLoopedVariableDescriptor<Solution_> shadowVariableLoopedDescriptor,
1919
MemberAccessor memberAccessor,
2020
Function<Object, Object> calculator,
21-
@Nullable Object[] groupEntities) {
21+
@Nullable Object[] groupEntities,
22+
boolean isGroupAligned) {
2223

2324
public VariableUpdaterInfo(VariableMetaModel<Solution_, ?, ?> id,
2425
int groupId,
2526
DeclarativeShadowVariableDescriptor<Solution_> variableDescriptor,
2627
@Nullable ShadowVariableLoopedVariableDescriptor<Solution_> shadowVariableLoopedDescriptor,
2728
MemberAccessor memberAccessor,
2829
Function<Object, Object> calculator) {
29-
this(id, groupId, variableDescriptor, shadowVariableLoopedDescriptor, memberAccessor, calculator, null);
30+
// isGroupAligned defaults to true, so we can just check it instead of checking
31+
// if groupEntities is null before determining what updateIfChanged to call
32+
this(id, groupId, variableDescriptor, shadowVariableLoopedDescriptor, memberAccessor, calculator, null, true);
3033
}
3134

3235
public VariableUpdaterInfo<Solution_> withGroupId(int groupId) {
3336
return new VariableUpdaterInfo<>(id, groupId, variableDescriptor, shadowVariableLoopedDescriptor, memberAccessor,
34-
calculator, groupEntities);
37+
calculator, groupEntities, isGroupAligned);
3538
}
3639

37-
public VariableUpdaterInfo<Solution_> withGroupEntities(Object[] groupEntities) {
40+
public VariableUpdaterInfo<Solution_> withGroupEntities(Object[] groupEntities, boolean isGroupAligned) {
3841
return new VariableUpdaterInfo<>(id, groupId, variableDescriptor, shadowVariableLoopedDescriptor, memberAccessor,
39-
calculator, groupEntities);
42+
calculator, groupEntities, isGroupAligned);
4043
}
4144

4245
public boolean updateIfChanged(Object entity, ChangedVariableNotifier<Solution_> changedVariableNotifier) {
43-
return updateIfChanged(entity, calculator.apply(entity), changedVariableNotifier);
46+
if (isGroupAligned) {
47+
return updateIfChanged(entity, calculator.apply(entity), changedVariableNotifier);
48+
} else {
49+
var anyChanged = false;
50+
for (var groupEntity : groupEntities) {
51+
var oldValue = variableDescriptor.getValue(groupEntity);
52+
var newValue = calculator.apply(groupEntity);
53+
54+
if (!Objects.equals(oldValue, newValue)) {
55+
changedVariableNotifier.beforeVariableChanged().accept(variableDescriptor, groupEntity);
56+
variableDescriptor.setValue(groupEntity, newValue);
57+
changedVariableNotifier.afterVariableChanged().accept(variableDescriptor, groupEntity);
58+
anyChanged = true;
59+
}
60+
}
61+
return anyChanged;
62+
}
4463
}
4564

4665
public boolean updateIfChanged(Object entity, @Nullable Object newValue,

core/src/test/java/ai/timefold/solver/core/impl/domain/variable/listener/support/VariableListenerSupportTest.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,15 @@ void shadowVariableListGraphEvents() {
297297
// which is the first member of the group
298298
for (var element : visit.getConcurrentValueGroup()) {
299299
verifyAddEdge.accept(serviceReadyTime, element, serviceStartTime, visit);
300-
verifyAddEdge.accept(serviceStartTime, visit, serviceFinishTime, element);
300+
// start and finish time use the same node, so no edge between them
301301
}
302302
}
303303

304304
if (visit.getPreviousValue() != null) {
305-
verifyAddEdge.accept(serviceFinishTime, visit.getPreviousValue(), serviceReadyTime, visit);
305+
var previousRepresentative =
306+
visit.getPreviousValue().getConcurrentValueGroup() == null ? visit.getPreviousValue()
307+
: visit.getPreviousValue().getConcurrentValueGroup().get(0);
308+
verifyAddEdge.accept(serviceFinishTime, previousRepresentative, serviceReadyTime, visit);
306309
}
307310
}
308311
// Note: addEdge only adds an edge if it does not already exists in the graph,

0 commit comments

Comments
 (0)