Skip to content

Commit 9d9d8fb

Browse files
committed
chore: refactor constraint session logic into smaller pieces
1 parent f09a244 commit 9d9d8fb

File tree

3 files changed

+105
-46
lines changed

3 files changed

+105
-46
lines changed

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
package ai.timefold.solver.core.impl.score.stream.bavet;
22

3-
import java.util.Collections;
43
import java.util.IdentityHashMap;
5-
import java.util.List;
64
import java.util.Map;
75

86
import ai.timefold.solver.core.api.score.Score;
97
import ai.timefold.solver.core.api.score.constraint.ConstraintMatchTotal;
108
import ai.timefold.solver.core.api.score.constraint.Indictment;
119
import ai.timefold.solver.core.impl.score.director.stream.BavetConstraintStreamScoreDirectorFactory;
1210
import ai.timefold.solver.core.impl.score.stream.bavet.common.PropagationQueue;
13-
import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator;
1411
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;
1512
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;
1613

@@ -24,21 +21,17 @@
2421
public final class BavetConstraintSession<Score_ extends Score<Score_>> {
2522

2623
private final AbstractScoreInliner<Score_> scoreInliner;
27-
private final Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap;
28-
private final Propagator[][] layeredNodes; // First level is the layer, second determines iteration order.
24+
private final NodeNetwork nodeNetwork;
2925
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> effectiveClassToNodeArrayMap;
3026

3127
BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner) {
32-
this(scoreInliner, Collections.emptyMap(), new Propagator[0][0]);
28+
this(scoreInliner, NodeNetwork.EMPTY);
3329
}
3430

35-
BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner,
36-
Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap,
37-
Propagator[][] layeredNodes) {
31+
BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner, NodeNetwork nodeNetwork) {
3832
this.scoreInliner = scoreInliner;
39-
this.declaredClassToNodeMap = declaredClassToNodeMap;
40-
this.layeredNodes = layeredNodes;
41-
this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(declaredClassToNodeMap.size());
33+
this.nodeNetwork = nodeNetwork;
34+
this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
4235
}
4336

4437
public void insert(Object fact) {
@@ -52,12 +45,7 @@ private AbstractForEachUniNode<Object>[] findNodes(Class<?> factClass) {
5245
// Map.computeIfAbsent() would have created lambdas on the hot path, this will not.
5346
var nodeArray = effectiveClassToNodeArrayMap.get(factClass);
5447
if (nodeArray == null) {
55-
nodeArray = declaredClassToNodeMap.entrySet()
56-
.stream()
57-
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
58-
.map(Map.Entry::getValue)
59-
.flatMap(List::stream)
60-
.toArray(AbstractForEachUniNode[]::new);
48+
nodeArray = nodeNetwork.getApplicableForEachNodes(factClass);
6149
effectiveClassToNodeArrayMap.put(factClass, nodeArray);
6250
}
6351
return nodeArray;
@@ -78,31 +66,10 @@ public void retract(Object fact) {
7866
}
7967

8068
public Score_ calculateScore(int initScore) {
81-
var layerCount = layeredNodes.length;
82-
for (var layerIndex = 0; layerIndex < layerCount; layerIndex++) {
83-
calculateScoreInLayer(layerIndex);
84-
}
69+
nodeNetwork.propagate();
8570
return scoreInliner.extractScore(initScore);
8671
}
8772

88-
private void calculateScoreInLayer(int layerIndex) {
89-
var nodesInLayer = layeredNodes[layerIndex];
90-
var nodeCount = nodesInLayer.length;
91-
if (nodeCount == 1) {
92-
nodesInLayer[0].propagateEverything();
93-
} else {
94-
for (var node : nodesInLayer) {
95-
node.propagateRetracts();
96-
}
97-
for (var node : nodesInLayer) {
98-
node.propagateUpdates();
99-
}
100-
for (var node : nodesInLayer) {
101-
node.propagateInserts();
102-
}
103-
}
104-
}
105-
10673
public AbstractScoreInliner<Score_> getScoreInliner() {
10774
return scoreInliner;
10875
}

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import java.util.ArrayList;
44
import java.util.Collections;
5-
import java.util.HashMap;
65
import java.util.LinkedHashMap;
76
import java.util.LinkedHashSet;
87
import java.util.List;
98
import java.util.Objects;
9+
import java.util.Set;
1010
import java.util.TreeMap;
1111
import java.util.stream.Collectors;
1212

@@ -28,6 +28,7 @@
2828
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;
2929
import ai.timefold.solver.core.impl.score.stream.common.ConstraintLibrary;
3030
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;
31+
import ai.timefold.solver.core.impl.util.CollectionUtils;
3132

3233
import org.slf4j.Logger;
3334
import org.slf4j.LoggerFactory;
@@ -66,7 +67,7 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
6667
var scoreDefinition = solutionDescriptor.<Score_> getScoreDefinition();
6768
var zeroScore = scoreDefinition.getZeroScore();
6869
var constraintStreamSet = new LinkedHashSet<BavetAbstractConstraintStream<Solution_>>();
69-
var constraintWeightMap = new HashMap<Constraint, Score_>(constraintLibrary.getConstraints().size());
70+
var constraintWeightMap = CollectionUtils.<Constraint, Score_> newHashMap(constraintLibrary.getConstraints().size());
7071

7172
// Only log constraint weights if logging is enabled; otherwise we don't need to build the string.
7273
var constraintWeightLoggingEnabled = !scoreDirectorDerived && LOGGER.isEnabledForLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL);
@@ -118,6 +119,11 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
118119
LOGGER.atLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL)
119120
.log(constraintWeightString.toString().trim());
120121
}
122+
return new BavetConstraintSession<>(scoreInliner, buildNodeNetwork(constraintStreamSet, scoreInliner));
123+
}
124+
125+
private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNetwork(
126+
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, AbstractScoreInliner<Score_> scoreInliner) {
121127
/*
122128
* Build constraintStreamSet in reverse order to create downstream nodes first
123129
* so every node only has final variables (some of which have downstream node method references).
@@ -162,7 +168,7 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
162168
var layer = layerMap.get((long) i);
163169
layeredNodes[i] = layer.toArray(new Propagator[0]);
164170
}
165-
return new BavetConstraintSession<>(scoreInliner, declaredClassToNodeMap, layeredNodes);
171+
return new NodeNetwork(declaredClassToNodeMap, layeredNodes);
166172
}
167173

168174
/**
@@ -180,7 +186,8 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
180186
* @param buildHelper never null
181187
* @return at least 0
182188
*/
183-
private long determineLayerIndex(AbstractNode node, NodeBuildHelper<Score_> buildHelper) {
189+
private static <Score_ extends Score<Score_>> long determineLayerIndex(AbstractNode node,
190+
NodeBuildHelper<Score_> buildHelper) {
184191
if (node instanceof AbstractForEachUniNode<?>) { // ForEach nodes, and only they, are in layer 0.
185192
return 0;
186193
} else if (node instanceof AbstractJoinNode<?, ?, ?> joinNode) {
@@ -199,8 +206,8 @@ private long determineLayerIndex(AbstractNode node, NodeBuildHelper<Score_> buil
199206
}
200207
}
201208

202-
private long determineLayerIndexOfBinaryOperation(BavetStreamBinaryOperation<?> nodeCreator,
203-
NodeBuildHelper<Score_> buildHelper) {
209+
private static <Score_ extends Score<Score_>> long determineLayerIndexOfBinaryOperation(
210+
BavetStreamBinaryOperation<?> nodeCreator, NodeBuildHelper<Score_> buildHelper) {
204211
var leftParent = nodeCreator.getLeftParent();
205212
var rightParent = nodeCreator.getRightParent();
206213
var leftParentNode = buildHelper.findParentNode(leftParent);
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package ai.timefold.solver.core.impl.score.stream.bavet;
2+
3+
import java.util.Arrays;
4+
import java.util.List;
5+
import java.util.Map;
6+
import java.util.Objects;
7+
8+
import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator;
9+
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;
10+
11+
/**
12+
* Represents Bavet's network of nodes, specific to a particular session.
13+
* Nodes only used by disabled constraints have already been removed.
14+
*
15+
* @param declaredClassToNodeMap starting nodes, one for each class used in the constraints;
16+
* root nodes, layer index 0.
17+
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
18+
* propagation needs to happen in this order.
19+
*/
20+
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {
21+
22+
public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);
23+
24+
public int forEachNodeCount() {
25+
return declaredClassToNodeMap.size();
26+
}
27+
28+
public int layerCount() {
29+
return layeredNodes.length;
30+
}
31+
32+
@SuppressWarnings("unchecked")
33+
public AbstractForEachUniNode<Object>[] getApplicableForEachNodes(Class<?> factClass) {
34+
return declaredClassToNodeMap.entrySet()
35+
.stream()
36+
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
37+
.map(Map.Entry::getValue)
38+
.flatMap(List::stream)
39+
.toArray(AbstractForEachUniNode[]::new);
40+
}
41+
42+
public void propagate() {
43+
for (var layerIndex = 0; layerIndex < layerCount(); layerIndex++) {
44+
propagateInLayer(layeredNodes[layerIndex]);
45+
}
46+
}
47+
48+
private static void propagateInLayer(Propagator[] nodesInLayer) {
49+
var nodeCount = nodesInLayer.length;
50+
if (nodeCount == 1) {
51+
nodesInLayer[0].propagateEverything();
52+
} else {
53+
for (var node : nodesInLayer) {
54+
node.propagateRetracts();
55+
}
56+
for (var node : nodesInLayer) {
57+
node.propagateUpdates();
58+
}
59+
for (var node : nodesInLayer) {
60+
node.propagateInserts();
61+
}
62+
}
63+
}
64+
65+
@Override
66+
public boolean equals(Object o) {
67+
if (this == o)
68+
return true;
69+
if (!(o instanceof NodeNetwork that))
70+
return false;
71+
return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap)
72+
&& Objects.deepEquals(layeredNodes, that.layeredNodes);
73+
}
74+
75+
@Override
76+
public int hashCode() {
77+
return Objects.hash(declaredClassToNodeMap, Arrays.deepHashCode(layeredNodes));
78+
}
79+
80+
@Override
81+
public String toString() {
82+
return this.getClass().getSimpleName() + " with " + forEachNodeCount() + " forEach nodes.";
83+
}
84+
85+
}

0 commit comments

Comments
 (0)