Skip to content

Commit ee2410f

Browse files
perf: simplify handling of @ShadowVariableLooped (#1726)
- @ShadowVariableLooped is now updated before the calculator is called - Nodes in the graph now have entityId and groupEntityIds - A mapping is created from entityId to graphNodeId, which LoopedTracker will used to lookup the nodes corresponding to a given entityId.
1 parent 0be77ae commit ee2410f

File tree

14 files changed

+265
-141
lines changed

14 files changed

+265
-141
lines changed

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/AbstractVariableReferenceGraph.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ public abstract sealed class AbstractVariableReferenceGraph<Solution_, ChangeSet
2020
permits DefaultVariableReferenceGraph, FixedVariableReferenceGraph {
2121

2222
// These structures are immutable.
23-
protected final List<EntityVariablePair<Solution_>> instanceList;
24-
protected final Map<VariableMetaModel<?, ?, ?>, Map<Object, EntityVariablePair<Solution_>>> variableReferenceToInstanceMap;
23+
protected final List<GraphNode<Solution_>> nodeList;
24+
protected final Map<VariableMetaModel<?, ?, ?>, Map<Object, GraphNode<Solution_>>> variableReferenceToContainingNodeMap;
2525
protected final Map<VariableMetaModel<?, ?, ?>, List<BiConsumer<AbstractVariableReferenceGraph<Solution_, ?>, Object>>> variableReferenceToBeforeProcessor;
2626
protected final Map<VariableMetaModel<?, ?, ?>, List<BiConsumer<AbstractVariableReferenceGraph<Solution_, ?>, Object>>> variableReferenceToAfterProcessor;
2727

@@ -32,22 +32,22 @@ public abstract sealed class AbstractVariableReferenceGraph<Solution_, ChangeSet
3232

3333
AbstractVariableReferenceGraph(VariableReferenceGraphBuilder<Solution_> outerGraph,
3434
IntFunction<TopologicalOrderGraph> graphCreator) {
35-
instanceList = List.copyOf(outerGraph.instanceList);
36-
var instanceCount = instanceList.size();
35+
nodeList = List.copyOf(outerGraph.nodeList);
36+
var instanceCount = nodeList.size();
3737
// Often the maps are a singleton; we improve performance by actually making it so.
38-
variableReferenceToInstanceMap = mapOfMapsDeepCopyOf(outerGraph.variableReferenceToInstanceMap);
38+
variableReferenceToContainingNodeMap = mapOfMapsDeepCopyOf(outerGraph.variableReferenceToContainingNodeMap);
3939
variableReferenceToBeforeProcessor = mapOfListsDeepCopyOf(outerGraph.variableReferenceToBeforeProcessor);
4040
variableReferenceToAfterProcessor = mapOfListsDeepCopyOf(outerGraph.variableReferenceToAfterProcessor);
4141
edgeCount = new DynamicIntArray[instanceCount];
4242
for (int i = 0; i < instanceCount; i++) {
4343
edgeCount[i] = new DynamicIntArray(instanceCount);
4444
}
4545
graph = graphCreator.apply(instanceCount);
46-
graph.withNodeData(instanceList);
46+
graph.withNodeData(nodeList);
4747

4848
var visited = Collections.newSetFromMap(new IdentityHashMap<>());
4949
changeSet = createChangeSet(instanceCount);
50-
for (var instance : instanceList) {
50+
for (var instance : nodeList) {
5151
var entity = instance.entity();
5252
if (visited.add(entity)) {
5353
for (var variableId : outerGraph.variableReferenceToAfterProcessor.keySet()) {
@@ -64,15 +64,15 @@ public abstract sealed class AbstractVariableReferenceGraph<Solution_, ChangeSet
6464

6565
protected abstract ChangeSet_ createChangeSet(int instanceCount);
6666

67-
public @Nullable EntityVariablePair<Solution_> lookupOrNull(VariableMetaModel<?, ?, ?> variableId, Object entity) {
68-
var map = variableReferenceToInstanceMap.get(variableId);
67+
public @Nullable GraphNode<Solution_> lookupOrNull(VariableMetaModel<?, ?, ?> variableId, Object entity) {
68+
var map = variableReferenceToContainingNodeMap.get(variableId);
6969
if (map == null) {
7070
return null;
7171
}
7272
return map.get(entity);
7373
}
7474

75-
public void addEdge(@NonNull EntityVariablePair<Solution_> from, @NonNull EntityVariablePair<Solution_> to) {
75+
public void addEdge(@NonNull GraphNode<Solution_> from, @NonNull GraphNode<Solution_> to) {
7676
var fromNodeId = from.graphNodeId();
7777
var toNodeId = to.graphNodeId();
7878
if (fromNodeId == toNodeId) {
@@ -87,7 +87,7 @@ public void addEdge(@NonNull EntityVariablePair<Solution_> from, @NonNull Entity
8787
markChanged(to);
8888
}
8989

90-
public void removeEdge(@NonNull EntityVariablePair<Solution_> from, @NonNull EntityVariablePair<Solution_> to) {
90+
public void removeEdge(@NonNull GraphNode<Solution_> from, @NonNull GraphNode<Solution_> to) {
9191
var fromNodeId = from.graphNodeId();
9292
var toNodeId = to.graphNodeId();
9393
if (fromNodeId == toNodeId) {
@@ -102,7 +102,7 @@ public void removeEdge(@NonNull EntityVariablePair<Solution_> from, @NonNull Ent
102102
markChanged(to);
103103
}
104104

105-
abstract void markChanged(EntityVariablePair<Solution_> changed);
105+
abstract void markChanged(GraphNode<Solution_> changed);
106106

107107
@Override
108108
public void beforeVariableChanged(VariableMetaModel<?, ?, ?> variableReference, Object entity) {
@@ -135,9 +135,9 @@ public void afterVariableChanged(VariableMetaModel<?, ?, ?> variableReference, O
135135

136136
@Override
137137
public String toString() {
138-
var edgeList = new LinkedHashMap<EntityVariablePair<Solution_>, List<EntityVariablePair<Solution_>>>();
139-
graph.forEachEdge((from, to) -> edgeList.computeIfAbsent(instanceList.get(from), k -> new ArrayList<>())
140-
.add(instanceList.get(to)));
138+
var edgeList = new LinkedHashMap<GraphNode<Solution_>, List<GraphNode<Solution_>>>();
139+
graph.forEachEdge((from, to) -> edgeList.computeIfAbsent(nodeList.get(from), k -> new ArrayList<>())
140+
.add(nodeList.get(to)));
141141
return edgeList.entrySet()
142142
.stream()
143143
.map(e -> e.getKey() + "->" + e.getValue())
Lines changed: 74 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,76 @@
11
package ai.timefold.solver.core.impl.domain.variable.declarative;
22

3-
import java.util.Arrays;
43
import java.util.BitSet;
4+
import java.util.IdentityHashMap;
55
import java.util.List;
66
import java.util.Objects;
77
import java.util.PriorityQueue;
8-
import java.util.Set;
98
import java.util.function.Consumer;
109
import java.util.function.Function;
10+
import java.util.stream.Collectors;
1111

1212
import ai.timefold.solver.core.impl.domain.variable.descriptor.VariableDescriptor;
13-
import ai.timefold.solver.core.impl.util.LinkedIdentityHashSet;
1413

1514
final class AffectedEntitiesUpdater<Solution_>
1615
implements Consumer<BitSet> {
1716

1817
// From WorkingReferenceGraph.
1918
private final BaseTopologicalOrderGraph graph;
20-
private final List<EntityVariablePair<Solution_>> instanceList; // Immutable.
21-
private final Function<Object, List<EntityVariablePair<Solution_>>> entityVariablePairFunction;
19+
private final List<GraphNode<Solution_>> nodeList; // Immutable.
2220
private final ChangedVariableNotifier<Solution_> changedVariableNotifier;
2321

2422
// Internal state; expensive to create, therefore we reuse.
25-
private final AffectedEntities<Solution_> affectedEntities;
2623
private final LoopedTracker loopedTracker;
2724
private final BitSet visited;
2825
private final PriorityQueue<BaseTopologicalOrderGraph.NodeTopologicalOrder> changeQueue;
2926

30-
AffectedEntitiesUpdater(BaseTopologicalOrderGraph graph, List<EntityVariablePair<Solution_>> instanceList,
31-
Function<Object, List<EntityVariablePair<Solution_>>> entityVariablePairFunction,
32-
ChangedVariableNotifier<Solution_> changedVariableNotifier) {
27+
AffectedEntitiesUpdater(BaseTopologicalOrderGraph graph, List<GraphNode<Solution_>> nodeList,
28+
Function<Object, List<GraphNode<Solution_>>> entityToContainingNode,
29+
int entityCount, ChangedVariableNotifier<Solution_> changedVariableNotifier) {
3330
this.graph = graph;
34-
this.instanceList = instanceList;
35-
this.entityVariablePairFunction = entityVariablePairFunction;
31+
this.nodeList = nodeList;
3632
this.changedVariableNotifier = changedVariableNotifier;
37-
var instanceCount = instanceList.size();
38-
this.affectedEntities = new AffectedEntities<>(this::updateLoopedStatusOfAffectedEntity);
39-
this.loopedTracker = new LoopedTracker(instanceCount);
33+
var instanceCount = nodeList.size();
34+
this.loopedTracker = new LoopedTracker(instanceCount,
35+
createNodeToEntityNodes(entityCount, nodeList, entityToContainingNode));
4036
this.visited = new BitSet(instanceCount);
4137
this.changeQueue = new PriorityQueue<>(instanceCount);
4238
}
4339

40+
static <Solution_> int[][] createNodeToEntityNodes(int entityCount,
41+
List<GraphNode<Solution_>> nodeList,
42+
Function<Object, List<GraphNode<Solution_>>> entityToContainingNode) {
43+
record EntityIdPair(Object entity, int entityId) {
44+
@Override
45+
public boolean equals(Object o) {
46+
if (!(o instanceof EntityIdPair that))
47+
return false;
48+
return entityId == that.entityId;
49+
}
50+
51+
@Override
52+
public int hashCode() {
53+
return Objects.hashCode(entityId);
54+
}
55+
}
56+
int[][] out = new int[entityCount][];
57+
var entityToNodes = new IdentityHashMap<Integer, int[]>();
58+
var entityIdPairSet = nodeList.stream()
59+
.map(node -> new EntityIdPair(node.entity(), node.entityId()))
60+
.collect(Collectors.toSet());
61+
for (var entityIdPair : entityIdPairSet) {
62+
entityToNodes.put(entityIdPair.entityId(),
63+
entityToContainingNode.apply(entityIdPair.entity).stream().mapToInt(GraphNode::graphNodeId)
64+
.toArray());
65+
}
66+
67+
for (var entry : entityToNodes.entrySet()) {
68+
out[entry.getKey()] = entry.getValue();
69+
}
70+
71+
return out;
72+
}
73+
4474
@Override
4575
public void accept(BitSet changed) {
4676
initializeChangeQueue(changed);
@@ -51,7 +81,7 @@ public void accept(BitSet changed) {
5181
continue;
5282
}
5383
visited.set(nextNode);
54-
var shadowVariable = instanceList.get(nextNode);
84+
var shadowVariable = nodeList.get(nextNode);
5585
var isChanged = updateEntityShadowVariables(shadowVariable, graph.isLooped(loopedTracker, nextNode));
5686

5787
if (isChanged) {
@@ -65,7 +95,6 @@ public void accept(BitSet changed) {
6595
}
6696
}
6797

68-
affectedEntities.processAndClear();
6998
// Prepare for the next time updateChanged() is called.
7099
// No need to clear changeQueue, as that already finishes empty.
71100
loopedTracker.clear();
@@ -91,52 +120,50 @@ private void initializeChangeQueue(BitSet changed) {
91120
changed.clear();
92121
}
93122

94-
private void updateLoopedStatusOfAffectedEntity(Object affectedEntity) {
95-
ShadowVariableLoopedVariableDescriptor<Solution_> shadowVariableLoopedDescriptor = null;
96-
var isEntityLooped = false;
97-
for (var node : entityVariablePairFunction.apply(affectedEntity)) {
98-
// All variables come from the same entity,
99-
// therefore all have the same looped marker.
100-
shadowVariableLoopedDescriptor = node.variableReferences().get(0).shadowVariableLoopedDescriptor();
101-
if (graph.isLooped(loopedTracker, node.graphNodeId())) {
102-
isEntityLooped = true;
103-
break;
104-
}
105-
}
106-
if (shadowVariableLoopedDescriptor == null) {
107-
// At this point, affectedEntity is guaranteed to have looped marker.
108-
// Otherwise AffectedEntities would not have sent it here.
109-
throw new IllegalStateException("Impossible state: loop marker descriptor does not exist.");
110-
}
111-
var oldValue = shadowVariableLoopedDescriptor.getValue(affectedEntity);
112-
if (!Objects.equals(oldValue, isEntityLooped)) {
113-
changeShadowVariableAndNotify(shadowVariableLoopedDescriptor, affectedEntity, isEntityLooped);
114-
}
115-
116-
}
117-
118-
private boolean updateEntityShadowVariables(EntityVariablePair<Solution_> entityVariable, boolean isLooped) {
123+
private boolean updateEntityShadowVariables(GraphNode<Solution_> entityVariable, boolean isVariableLooped) {
119124
var entity = entityVariable.entity();
120125
var shadowVariableReferences = entityVariable.variableReferences();
121126
var loopDescriptor = shadowVariableReferences.get(0).shadowVariableLoopedDescriptor();
122127
var anyChanged = false;
123128

124129
if (loopDescriptor != null) {
125-
var oldLooped = loopDescriptor.getValue(entity);
126-
if (!Objects.equals(oldLooped, isLooped)) {
127-
// Loop status change; add to affected entities
128-
affectedEntities.add(entityVariable);
129-
anyChanged = true;
130+
// Do not need to update anyChanged here; the graph already marked
131+
// all nodes whose looped status changed for us
132+
var groupEntities = shadowVariableReferences.get(0).groupEntities();
133+
var groupEntityIds = entityVariable.groupEntityIds();
134+
135+
if (groupEntities != null) {
136+
for (var i = 0; i < groupEntityIds.length; i++) {
137+
var groupEntity = groupEntities[i];
138+
var groupEntityId = groupEntityIds[i];
139+
anyChanged |= updateLoopedStatusOfEntity(groupEntity, groupEntityId, loopDescriptor);
140+
}
141+
} else {
142+
anyChanged |= updateLoopedStatusOfEntity(entity, entityVariable.entityId(), loopDescriptor);
130143
}
131144
}
132145

133146
for (var shadowVariableReference : shadowVariableReferences) {
134-
anyChanged |= updateShadowVariable(isLooped, shadowVariableReference, entity);
147+
anyChanged |= updateShadowVariable(isVariableLooped, shadowVariableReference, entity);
135148
}
136149

137150
return anyChanged;
138151
}
139152

153+
private boolean updateLoopedStatusOfEntity(Object entity, int entityId,
154+
ShadowVariableLoopedVariableDescriptor<Solution_> loopDescriptor) {
155+
var oldLooped = (boolean) loopDescriptor.getValue(entity);
156+
var isEntityLooped = loopedTracker.isEntityLooped(graph, entityId, oldLooped);
157+
if (!Objects.equals(oldLooped, isEntityLooped)) {
158+
changeShadowVariableAndNotify(loopDescriptor, entity, isEntityLooped);
159+
}
160+
// We return true if the entity's loop status changed at any point;
161+
// Since an entity might correspond to multiple nodes, we want all nodes
162+
// for that entity to be marked as changed, not just the first node the
163+
// updater encounters
164+
return loopedTracker.didEntityLoopedStatusChange(entityId);
165+
}
166+
140167
private boolean updateShadowVariable(boolean isLooped,
141168
VariableUpdaterInfo<Solution_> shadowVariableReference, Object entity) {
142169
if (isLooped) {
@@ -152,37 +179,4 @@ private void changeShadowVariableAndNotify(VariableDescriptor<Solution_> variabl
152179
variableDescriptor.setValue(entity, newValue);
153180
changedVariableNotifier.afterVariableChanged().accept(variableDescriptor, entity);
154181
}
155-
156-
private static final class AffectedEntities<Solution_> {
157-
158-
private final Consumer<Object> consumer;
159-
private final Set<Object> entitiesForLoopedVarUpdateSet;
160-
161-
public AffectedEntities(Consumer<Object> consumer) {
162-
this.consumer = consumer;
163-
this.entitiesForLoopedVarUpdateSet = new LinkedIdentityHashSet<>();
164-
}
165-
166-
public void add(EntityVariablePair<Solution_> shadowVariable) {
167-
var shadowVariableLoopedDescriptor = shadowVariable.variableReferences().get(0).shadowVariableLoopedDescriptor();
168-
if (shadowVariableLoopedDescriptor == null) {
169-
return;
170-
}
171-
var entityGroup = shadowVariable.variableReferences().get(0).groupEntities();
172-
if (entityGroup == null) {
173-
entitiesForLoopedVarUpdateSet.add(shadowVariable.entity());
174-
} else {
175-
entitiesForLoopedVarUpdateSet.addAll(Arrays.asList(entityGroup));
176-
}
177-
}
178-
179-
public void processAndClear() {
180-
for (var entity : entitiesForLoopedVarUpdateSet) {
181-
consumer.accept(entity);
182-
}
183-
entitiesForLoopedVarUpdateSet.clear();
184-
}
185-
186-
}
187-
188182
}

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DeclarativeShadowVariableDescriptor.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import ai.timefold.solver.core.preview.api.domain.metamodel.PlanningSolutionMetaModel;
2424
import ai.timefold.solver.core.preview.api.domain.variable.declarative.ShadowSources;
2525

26+
import org.jspecify.annotations.Nullable;
27+
2628
public class DeclarativeShadowVariableDescriptor<Solution_> extends ShadowVariableDescriptor<Solution_> {
2729
MemberAccessor calculator;
2830
RootVariableSource<?, ?>[] sources;
@@ -188,6 +190,10 @@ public Function<Object, Object> getAlignmentKeyMap() {
188190
return alignmentKeyMap;
189191
}
190192

193+
public @Nullable String getAlignmentKeyName() {
194+
return alignmentKey;
195+
}
196+
191197
public RootVariableSource<?, ?>[] getSources() {
192198
return sources;
193199
}

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultVariableReferenceGraph.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ public DefaultVariableReferenceGraph(VariableReferenceGraphBuilder<Solution_> ou
1717
IntFunction<TopologicalOrderGraph> graphCreator) {
1818
super(outerGraph, graphCreator);
1919

20-
var entityToVariableReferenceMap = new IdentityHashMap<Object, List<EntityVariablePair<Solution_>>>();
21-
for (var instance : instanceList) {
20+
var entityToVariableReferenceMap = new IdentityHashMap<Object, List<GraphNode<Solution_>>>();
21+
for (var instance : nodeList) {
2222
var entity = instance.entity();
2323
entityToVariableReferenceMap.computeIfAbsent(entity, ignored -> new ArrayList<>())
2424
.add(instance);
@@ -28,8 +28,9 @@ public DefaultVariableReferenceGraph(VariableReferenceGraphBuilder<Solution_> ou
2828
// This mutable structure is created once, and reused from there on.
2929
// Otherwise its internal collections were observed being re-created so often
3030
// that the allocation of arrays would become a major bottleneck.
31-
affectedEntitiesUpdater = new AffectedEntitiesUpdater<>(graph, instanceList, immutableEntityToVariableReferenceMap::get,
32-
outerGraph.changedVariableNotifier);
31+
affectedEntitiesUpdater =
32+
new AffectedEntitiesUpdater<>(graph, nodeList, immutableEntityToVariableReferenceMap::get,
33+
outerGraph.entityToEntityId.size(), outerGraph.changedVariableNotifier);
3334
}
3435

3536
@Override
@@ -38,7 +39,7 @@ protected BitSet createChangeSet(int instanceCount) {
3839
}
3940

4041
@Override
41-
public void markChanged(@NonNull EntityVariablePair<Solution_> node) {
42+
public void markChanged(@NonNull GraphNode<Solution_> node) {
4243
changeSet.set(node.graphNodeId());
4344
}
4445

0 commit comments

Comments
 (0)