Skip to content

Commit df5f405

Browse files
authored
perf: revert most of 2c4715f (#1319)
The commit in question caused a 10 % regression in some benchmarks. We only keep changes that are guaranteed not to regress. The reverted changes will be re-evaluated and, if beneficial, re-submitted later.
1 parent c41ba83 commit df5f405

32 files changed

+427
-444
lines changed

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -125,35 +125,33 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, Co
125125
buildNodeNetwork(workingSolution, constraintStreamSet, scoreInliner, nodeNetworkVisualizationConsumer));
126126
}
127127

128-
@SuppressWarnings("unchecked")
129128
private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNetwork(Solution_ workingSolution,
130129
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, AbstractScoreInliner<Score_> scoreInliner,
131130
Consumer<String> nodeNetworkVisualizationConsumer) {
132-
/*
133-
* Build constraintStreamSet in reverse order to create downstream nodes first
134-
* so every node only has final variables (some of which have downstream node method references).
135-
*/
136131
var buildHelper = new NodeBuildHelper<>(constraintStreamSet, scoreInliner);
137-
var nodeList = buildNodeList(constraintStreamSet, buildHelper);
132+
var declaredClassToNodeMap = new LinkedHashMap<Class<?>, List<AbstractForEachUniNode<?>>>();
133+
var nodeList = buildNodeList(constraintStreamSet, buildHelper, node -> {
134+
if (!(node instanceof AbstractForEachUniNode<?> forEachUniNode)) {
135+
return;
136+
}
137+
var forEachClass = forEachUniNode.getForEachClass();
138+
var forEachUniNodeList =
139+
declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2));
140+
if (forEachUniNodeList.size() == 2) {
141+
// Each class can have at most two forEach nodes: one including null vars, the other excluding them.
142+
throw new IllegalStateException(
143+
"Impossible state: For class (%s) there are already 2 nodes (%s), not adding another (%s)."
144+
.formatted(forEachClass, forEachUniNodeList, forEachUniNode));
145+
}
146+
forEachUniNodeList.add(forEachUniNode);
147+
});
138148
if (nodeNetworkVisualizationConsumer != null) {
139-
var visualisation = visualizeNodeNetwork(workingSolution, buildHelper, scoreInliner, nodeList);
149+
var constraintSet = scoreInliner.getConstraints();
150+
var visualisation = NodeGraph.of(workingSolution, nodeList, constraintSet, buildHelper::getNodeCreatingStream,
151+
buildHelper::findParentNode)
152+
.buildGraphvizDOT();
140153
nodeNetworkVisualizationConsumer.accept(visualisation);
141154
}
142-
var declaredClassToNodeMap = new LinkedHashMap<Class<?>, List<AbstractForEachUniNode<Object>>>();
143-
for (var node : nodeList) {
144-
if (node instanceof AbstractForEachUniNode<?> forEachUniNode) {
145-
var forEachClass = forEachUniNode.getForEachClass();
146-
var forEachUniNodeList =
147-
declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>());
148-
if (forEachUniNodeList.size() == 2) {
149-
// Each class can have at most two forEach nodes: one including null vars, the other excluding them.
150-
throw new IllegalStateException("Impossible state: For class (" + forEachClass
151-
+ ") there are already 2 nodes (" + forEachUniNodeList + "), not adding another ("
152-
+ forEachUniNode + ").");
153-
}
154-
forEachUniNodeList.add((AbstractForEachUniNode<Object>) forEachUniNode);
155-
}
156-
}
157155
var layerMap = new TreeMap<Long, List<Propagator>>();
158156
for (var node : nodeList) {
159157
layerMap.computeIfAbsent(node.getLayerIndex(), k -> new ArrayList<>())
@@ -169,7 +167,8 @@ private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNe
169167
}
170168

171169
private static <Solution_, Score_ extends Score<Score_>> List<AbstractNode> buildNodeList(
172-
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, NodeBuildHelper<Score_> buildHelper) {
170+
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, NodeBuildHelper<Score_> buildHelper,
171+
Consumer<AbstractNode> nodeProcessor) {
173172
/*
174173
* Build constraintStreamSet in reverse order to create downstream nodes first
175174
* so every node only has final variables (some of which have downstream node method references).
@@ -188,16 +187,11 @@ private static <Solution_, Score_ extends Score<Score_>> List<AbstractNode> buil
188187
*/
189188
node.setId(nextNodeId++);
190189
node.setLayerIndex(determineLayerIndex(node, buildHelper));
190+
nodeProcessor.accept(node);
191191
}
192192
return nodeList;
193193
}
194194

195-
public static <Solution_, Score_ extends Score<Score_>> String visualizeNodeNetwork(Solution_ solution,
196-
NodeBuildHelper<Score_> buildHelper, AbstractScoreInliner<Score_> scoreInliner, List<AbstractNode> nodeList) {
197-
return NodeGraph.of(solution, buildHelper, nodeList, scoreInliner)
198-
.buildGraphvizDOT();
199-
}
200-
201195
/**
202196
* Nodes are propagated in layers.
203197
* See {@link PropagationQueue} and {@link AbstractNode} for details.

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/NodeNetwork.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
1818
* propagation needs to happen in this order.
1919
*/
20-
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {
20+
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {
2121

2222
public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);
2323

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/IndexedIfExistsBiNode.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
package ai.timefold.solver.core.impl.score.stream.bavet.bi;
22

3+
import java.util.function.BiFunction;
4+
import java.util.function.Function;
5+
36
import ai.timefold.solver.core.api.function.TriPredicate;
47
import ai.timefold.solver.core.impl.score.stream.bavet.common.AbstractIndexedIfExistsNode;
58
import ai.timefold.solver.core.impl.score.stream.bavet.common.ExistsCounter;
9+
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties;
610
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer;
7-
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexerFactory.BiMapping;
8-
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexerFactory.UniMapping;
911
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.BiTuple;
1012
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.TupleLifecycle;
1113
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.UniTuple;
1214

1315
final class IndexedIfExistsBiNode<A, B, C> extends AbstractIndexedIfExistsNode<BiTuple<A, B>, C> {
1416

15-
private final BiMapping<A, B> mappingAB;
17+
private final BiFunction<A, B, IndexProperties> mappingAB;
1618
private final TriPredicate<A, B, C> filtering;
1719

1820
public IndexedIfExistsBiNode(boolean shouldExist,
19-
BiMapping<A, B> mappingAB, UniMapping<C> mappingC,
21+
BiFunction<A, B, IndexProperties> mappingAB, Function<C, IndexProperties> mappingC,
2022
int inputStoreIndexLeftProperties, int inputStoreIndexLeftCounterEntry, int inputStoreIndexRightProperties,
2123
int inputStoreIndexRightEntry,
2224
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle,
@@ -29,7 +31,7 @@ public IndexedIfExistsBiNode(boolean shouldExist,
2931
}
3032

3133
public IndexedIfExistsBiNode(boolean shouldExist,
32-
BiMapping<A, B> mappingAB, UniMapping<C> mappingC,
34+
BiFunction<A, B, IndexProperties> mappingAB, Function<C, IndexProperties> mappingC,
3335
int inputStoreIndexLeftProperties, int inputStoreIndexLeftCounterEntry, int inputStoreIndexLeftTrackerList,
3436
int inputStoreIndexRightProperties, int inputStoreIndexRightEntry, int inputStoreIndexRightTrackerList,
3537
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle,
@@ -45,7 +47,7 @@ public IndexedIfExistsBiNode(boolean shouldExist,
4547
}
4648

4749
@Override
48-
protected Object createIndexProperties(BiTuple<A, B> leftTuple) {
50+
protected IndexProperties createIndexProperties(BiTuple<A, B> leftTuple) {
4951
return mappingAB.apply(leftTuple.factA, leftTuple.factB);
5052
}
5153

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/IndexedJoinBiNode.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
11
package ai.timefold.solver.core.impl.score.stream.bavet.bi;
22

33
import java.util.function.BiPredicate;
4+
import java.util.function.Function;
45

56
import ai.timefold.solver.core.impl.score.stream.bavet.common.AbstractIndexedJoinNode;
7+
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties;
68
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer;
7-
import ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexerFactory.UniMapping;
89
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.BiTuple;
910
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.TupleLifecycle;
1011
import ai.timefold.solver.core.impl.score.stream.bavet.common.tuple.UniTuple;
1112

1213
final class IndexedJoinBiNode<A, B> extends AbstractIndexedJoinNode<UniTuple<A>, B, BiTuple<A, B>> {
1314

14-
private final UniMapping<A> mappingA;
15+
private final Function<A, IndexProperties> mappingA;
1516
private final BiPredicate<A, B> filtering;
1617
private final int outputStoreSize;
1718

18-
public IndexedJoinBiNode(UniMapping<A> mappingA, UniMapping<B> mappingB,
19+
public IndexedJoinBiNode(Function<A, IndexProperties> mappingA, Function<B, IndexProperties> mappingB,
1920
int inputStoreIndexA, int inputStoreIndexEntryA, int inputStoreIndexOutTupleListA,
2021
int inputStoreIndexB, int inputStoreIndexEntryB, int inputStoreIndexOutTupleListB,
2122
TupleLifecycle<BiTuple<A, B>> nextNodesTupleLifecycle, BiPredicate<A, B> filtering,
22-
int outputStoreSize, int outputStoreIndexOutEntryA, int outputStoreIndexOutEntryB,
23-
Indexer<UniTuple<A>> indexerA, Indexer<UniTuple<B>> indexerB) {
23+
int outputStoreSize,
24+
int outputStoreIndexOutEntryA, int outputStoreIndexOutEntryB,
25+
Indexer<UniTuple<A>> indexerA,
26+
Indexer<UniTuple<B>> indexerB) {
2427
super(mappingB,
2528
inputStoreIndexA, inputStoreIndexEntryA, inputStoreIndexOutTupleListA,
2629
inputStoreIndexB, inputStoreIndexEntryB, inputStoreIndexOutTupleListB,
@@ -33,7 +36,7 @@ public IndexedJoinBiNode(UniMapping<A> mappingA, UniMapping<B> mappingB,
3336
}
3437

3538
@Override
36-
protected Object createIndexPropertiesLeft(UniTuple<A> leftTuple) {
39+
protected IndexProperties createIndexPropertiesLeft(UniTuple<A> leftTuple) {
3740
return mappingA.apply(leftTuple.factA);
3841
}
3942

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/AbstractIfExistsNode.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ protected final void updateUnchangedCounterLeft(ExistsCounter<LeftTuple_> counte
6060
}
6161

6262
protected void updateCounterLeft(ExistsCounter<LeftTuple_> counter) {
63-
var state = counter.state;
63+
TupleState state = counter.state;
6464
if (shouldExist ? counter.countRight > 0 : counter.countRight == 0) {
6565
// Insert or update
6666
switch (state) {
@@ -120,14 +120,14 @@ protected void decrementCounterRight(ExistsCounter<LeftTuple_> counter) {
120120

121121
protected ElementAwareList<FilteringTracker<LeftTuple_>> updateRightTrackerList(UniTuple<Right_> rightTuple) {
122122
ElementAwareList<FilteringTracker<LeftTuple_>> rightTrackerList = rightTuple.getStore(inputStoreIndexRightTrackerList);
123-
rightTrackerList.forEach(tracker -> {
124-
decrementCounterRight(tracker.counter);
125-
tracker.remove();
126-
});
123+
for (FilteringTracker<LeftTuple_> tuple : rightTrackerList) {
124+
decrementCounterRight(tuple.counter);
125+
tuple.remove();
126+
}
127127
return rightTrackerList;
128128
}
129129

130-
protected void updateCounterFromLeft(UniTuple<Right_> rightTuple, LeftTuple_ leftTuple, ExistsCounter<LeftTuple_> counter,
130+
protected void updateCounterFromLeft(LeftTuple_ leftTuple, UniTuple<Right_> rightTuple, ExistsCounter<LeftTuple_> counter,
131131
ElementAwareList<FilteringTracker<LeftTuple_>> leftTrackerList) {
132132
if (testFiltering(leftTuple, rightTuple)) {
133133
counter.countRight++;
@@ -137,12 +137,12 @@ protected void updateCounterFromLeft(UniTuple<Right_> rightTuple, LeftTuple_ lef
137137
}
138138
}
139139

140-
protected void updateCounterFromRight(ExistsCounter<LeftTuple_> counter, UniTuple<Right_> rightTuple,
140+
protected void updateCounterFromRight(UniTuple<Right_> rightTuple, ExistsCounter<LeftTuple_> counter,
141141
ElementAwareList<FilteringTracker<LeftTuple_>> rightTrackerList) {
142-
var leftTuple = counter.leftTuple;
143-
if (testFiltering(leftTuple, rightTuple)) {
142+
if (testFiltering(counter.leftTuple, rightTuple)) {
144143
incrementCounterRight(counter);
145-
ElementAwareList<FilteringTracker<LeftTuple_>> leftTrackerList = leftTuple.getStore(inputStoreIndexLeftTrackerList);
144+
ElementAwareList<FilteringTracker<LeftTuple_>> leftTrackerList =
145+
counter.leftTuple.getStore(inputStoreIndexLeftTrackerList);
146146
new FilteringTracker<>(counter, leftTrackerList, rightTrackerList);
147147
}
148148
}
@@ -173,7 +173,6 @@ public Propagator getPropagator() {
173173
}
174174

175175
protected static final class FilteringTracker<LeftTuple_ extends AbstractTuple> {
176-
177176
final ExistsCounter<LeftTuple_> counter;
178177
private final ElementAwareListEntry<FilteringTracker<LeftTuple_>> leftTrackerEntry;
179178
private final ElementAwareListEntry<FilteringTracker<LeftTuple_>> rightTrackerEntry;

0 commit comments

Comments
 (0)