Skip to content

Commit 3616a02

Browse files
committed
Implement distinct()
1 parent 84efd57 commit 3616a02

File tree

10 files changed

+239
-61
lines changed

10 files changed

+239
-61
lines changed

core/src/main/java/ai/timefold/solver/core/impl/localsearch/DefaultLocalSearchPhaseFactory.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
package ai.timefold.solver.core.impl.localsearch;
22

3-
import java.util.ArrayList;
4-
import java.util.Collections;
5-
import java.util.List;
6-
import java.util.Objects;
7-
83
import ai.timefold.solver.core.api.domain.entity.PinningFilter;
94
import ai.timefold.solver.core.api.domain.entity.PlanningPin;
105
import ai.timefold.solver.core.api.domain.variable.PlanningListVariable;
@@ -50,6 +45,11 @@
5045
import ai.timefold.solver.core.impl.solver.termination.SolverTermination;
5146
import ai.timefold.solver.core.preview.api.domain.metamodel.PlanningVariableMetaModel;
5247

48+
import java.util.ArrayList;
49+
import java.util.Collections;
50+
import java.util.List;
51+
import java.util.Objects;
52+
5353
public class DefaultLocalSearchPhaseFactory<Solution_> extends AbstractPhaseFactory<Solution_, LocalSearchPhaseConfig> {
5454

5555
public DefaultLocalSearchPhaseFactory(LocalSearchPhaseConfig phaseConfig) {
@@ -134,7 +134,7 @@ Convert your entities (%s) to use @%s instead."""
134134
.formatted(moveProvidersClass, moveProviderList.size()));
135135
}
136136
var moveProvider = moveProviderList.get(0);
137-
var moveStreamFactory = new DefaultMoveStreamFactory<>(solutionDescriptor);
137+
var moveStreamFactory = new DefaultMoveStreamFactory<>(solutionDescriptor, configPolicy.getEnvironmentMode());
138138
var moveProducer = moveProvider.apply(moveStreamFactory);
139139
var moveRepository = new MoveStreamsBasedMoveRepository<>(moveStreamFactory, moveProducer,
140140
pickSelectionOrder() == SelectionOrder.RANDOM);

core/src/main/java/ai/timefold/solver/core/impl/move/streams/DefaultMoveStreamFactory.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package ai.timefold.solver.core.impl.move.streams;
22

3+
import ai.timefold.solver.core.config.solver.EnvironmentMode;
34
import ai.timefold.solver.core.impl.domain.solution.descriptor.DefaultPlanningListVariableMetaModel;
45
import ai.timefold.solver.core.impl.domain.solution.descriptor.DefaultPlanningVariableMetaModel;
56
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
@@ -28,8 +29,8 @@ public final class DefaultMoveStreamFactory<Solution_>
2829
private final DataStreamFactory<Solution_> dataStreamFactory;
2930
private final DatasetSessionFactory<Solution_> datasetSessionFactory;
3031

31-
public DefaultMoveStreamFactory(SolutionDescriptor<Solution_> solutionDescriptor) {
32-
this.dataStreamFactory = new DataStreamFactory<>(solutionDescriptor);
32+
public DefaultMoveStreamFactory(SolutionDescriptor<Solution_> solutionDescriptor, EnvironmentMode environmentMode) {
33+
this.dataStreamFactory = new DataStreamFactory<>(solutionDescriptor, environmentMode);
3334
this.datasetSessionFactory = new DatasetSessionFactory<>(dataStreamFactory);
3435
}
3536

core/src/main/java/ai/timefold/solver/core/impl/move/streams/dataset/DataStreamFactory.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
package ai.timefold.solver.core.impl.move.streams.dataset;
22

3-
import java.util.HashMap;
4-
import java.util.List;
5-
import java.util.Map;
6-
import java.util.Objects;
7-
import java.util.function.Consumer;
8-
import java.util.stream.Collectors;
9-
import java.util.stream.Stream;
10-
3+
import ai.timefold.solver.core.config.solver.EnvironmentMode;
114
import ai.timefold.solver.core.impl.domain.solution.descriptor.InnerGenuineVariableMetaModel;
125
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
136
import ai.timefold.solver.core.impl.move.streams.dataset.common.AbstractDataStream;
@@ -20,17 +13,26 @@
2013
import ai.timefold.solver.core.impl.move.streams.maybeapi.stream.UniDataStream;
2114
import ai.timefold.solver.core.impl.score.director.SessionContext;
2215
import ai.timefold.solver.core.preview.api.domain.metamodel.GenuineVariableMetaModel;
23-
2416
import org.jspecify.annotations.NullMarked;
2517

18+
import java.util.HashMap;
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.Objects;
22+
import java.util.function.Consumer;
23+
import java.util.stream.Collectors;
24+
import java.util.stream.Stream;
25+
2626
@NullMarked
2727
public final class DataStreamFactory<Solution_> {
2828

2929
private final SolutionDescriptor<Solution_> solutionDescriptor;
30+
private final EnvironmentMode environmentMode;
3031
private final Map<AbstractDataStream<Solution_>, AbstractDataStream<Solution_>> sharingStreamMap = new HashMap<>(256);
3132

32-
public DataStreamFactory(SolutionDescriptor<Solution_> solutionDescriptor) {
33+
public DataStreamFactory(SolutionDescriptor<Solution_> solutionDescriptor, EnvironmentMode environmentMode) {
3334
this.solutionDescriptor = Objects.requireNonNull(solutionDescriptor);
35+
this.environmentMode = Objects.requireNonNull(environmentMode);
3436
}
3537

3638
public <A> UniDataStream<Solution_, A> forEachNonDiscriminating(Class<A> sourceClass, boolean includeNull) {
@@ -131,6 +133,10 @@ public SolutionDescriptor<Solution_> getSolutionDescriptor() {
131133
return solutionDescriptor;
132134
}
133135

136+
public EnvironmentMode getEnvironmentMode() {
137+
return environmentMode;
138+
}
139+
134140
@SuppressWarnings("unchecked")
135141
public List<AbstractDataset<Solution_, ?>> getDatasets() {
136142
return sharingStreamMap.values().stream()

core/src/main/java/ai/timefold/solver/core/impl/move/streams/dataset/bi/AbstractBiDataStream.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package ai.timefold.solver.core.impl.move.streams.dataset.bi;
22

3+
import ai.timefold.solver.core.impl.bavet.bi.Group2Mapping0CollectorBiNode;
4+
import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor;
5+
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
36
import ai.timefold.solver.core.impl.move.streams.dataset.DataStreamFactory;
47
import ai.timefold.solver.core.impl.move.streams.dataset.common.AbstractDataStream;
58
import ai.timefold.solver.core.impl.move.streams.dataset.common.bridge.AftBridgeBiDataStream;
@@ -8,10 +11,12 @@
811
import ai.timefold.solver.core.impl.move.streams.maybeapi.BiDataMapper;
912
import ai.timefold.solver.core.impl.move.streams.maybeapi.stream.BiDataStream;
1013
import ai.timefold.solver.core.impl.move.streams.maybeapi.stream.UniDataStream;
11-
14+
import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
1215
import org.jspecify.annotations.NullMarked;
1316
import org.jspecify.annotations.Nullable;
1417

18+
import java.util.function.BiFunction;
19+
1520
@NullMarked
1621
public abstract class AbstractBiDataStream<Solution_, A, B> extends AbstractDataStream<Solution_>
1722
implements BiDataStream<Solution_, A, B> {
@@ -30,6 +35,18 @@ public final BiDataStream<Solution_, A, B> filter(BiDataFilter<Solution_, A, B>
3035
return shareAndAddChild(new FilterBiDataStream<>(dataStreamFactory, this, filter));
3136
}
3237

38+
39+
protected <GroupKeyA_, GroupKeyB_> AbstractBiDataStream<Solution_, GroupKeyA_, GroupKeyB_> groupBy(BiFunction<A, B, GroupKeyA_> groupKeyAMapping, BiFunction<A, B, GroupKeyB_> groupKeyBMapping) {
40+
GroupNodeConstructor<BiTuple<GroupKeyA_, GroupKeyB_>> nodeConstructor =
41+
GroupNodeConstructor.twoKeysGroupBy(groupKeyAMapping, groupKeyBMapping, Group2Mapping0CollectorBiNode::new);
42+
return buildBiGroupBy(nodeConstructor);
43+
}
44+
45+
private <NewA, NewB> AbstractBiDataStream<Solution_, NewA, NewB> buildBiGroupBy(GroupNodeConstructor<BiTuple<NewA, NewB>> nodeConstructor) {
46+
var stream = shareAndAddChild(new BiGroupBiDataStream<>(dataStreamFactory, this, nodeConstructor));
47+
return dataStreamFactory.share(new AftBridgeBiDataStream<>(dataStreamFactory, stream), stream::setAftBridge);
48+
}
49+
3350
@Override
3451
public <ResultA_> UniDataStream<Solution_, ResultA_> map(BiDataMapper<Solution_, A, B, ResultA_> mapping) {
3552
var stream = shareAndAddChild(new UniMapBiDataStream<>(dataStreamFactory, this, mapping));
@@ -43,11 +60,11 @@ public <ResultA_, ResultB_> BiDataStream<Solution_, ResultA_, ResultB_> map(BiDa
4360
}
4461

4562
@Override
46-
public BiDataStream<Solution_, A, B> distinct() {
63+
public AbstractBiDataStream<Solution_, A, B> distinct() {
4764
if (guaranteesDistinct()) {
4865
return this; // Already distinct, no need to create a new stream.
4966
}
50-
throw new UnsupportedOperationException();
67+
return groupBy(ConstantLambdaUtils.biPickFirst(), ConstantLambdaUtils.biPickSecond());
5168
}
5269

5370
public BiDataset<Solution_, A, B> createDataset() {
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package ai.timefold.solver.core.impl.move.streams.dataset.bi;
2+
3+
import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor;
4+
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
5+
import ai.timefold.solver.core.impl.move.streams.dataset.DataStreamFactory;
6+
import ai.timefold.solver.core.impl.move.streams.dataset.common.DataNodeBuildHelper;
7+
import ai.timefold.solver.core.impl.move.streams.dataset.common.bridge.AftBridgeBiDataStream;
8+
import org.jspecify.annotations.NullMarked;
9+
import org.jspecify.annotations.Nullable;
10+
11+
import java.util.Objects;
12+
13+
@NullMarked
14+
final class BiGroupBiDataStream<Solution_, A, B, NewA, NewB>
15+
extends AbstractBiDataStream<Solution_, A, B> {
16+
17+
private final GroupNodeConstructor<BiTuple<NewA, NewB>> nodeConstructor;
18+
private @Nullable AftBridgeBiDataStream<Solution_, NewA, NewB> aftStream;
19+
20+
public BiGroupBiDataStream(DataStreamFactory<Solution_> dataStreamFactory, AbstractBiDataStream<Solution_, A, B> parent,
21+
GroupNodeConstructor<BiTuple<NewA, NewB>> nodeConstructor) {
22+
super(dataStreamFactory, parent);
23+
this.nodeConstructor = nodeConstructor;
24+
}
25+
26+
public void setAftBridge(AftBridgeBiDataStream<Solution_, NewA, NewB> aftStream) {
27+
this.aftStream = aftStream;
28+
}
29+
30+
@Override
31+
public void buildNode(DataNodeBuildHelper<Solution_> buildHelper) {
32+
var aftStreamChildList = aftStream.getChildStreamList();
33+
nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, aftStreamChildList, this,
34+
dataStreamFactory.getEnvironmentMode());
35+
}
36+
37+
@Override
38+
public boolean equals(Object object) {
39+
if (this == object)
40+
return true;
41+
if (object == null || getClass() != object.getClass())
42+
return false;
43+
var that = (BiGroupBiDataStream<?, ?, ?, ?, ?>) object;
44+
return Objects.equals(parent, that.parent) && Objects.equals(nodeConstructor, that.nodeConstructor);
45+
}
46+
47+
@Override
48+
public int hashCode() {
49+
return Objects.hash(parent, nodeConstructor);
50+
}
51+
52+
@Override
53+
public String toString() {
54+
return "BiGroup()";
55+
}
56+
57+
}

core/src/main/java/ai/timefold/solver/core/impl/move/streams/dataset/uni/AbstractUniDataStream.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package ai.timefold.solver.core.impl.move.streams.dataset.uni;
22

3+
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
4+
import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor;
5+
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;
6+
import ai.timefold.solver.core.impl.bavet.uni.Group1Mapping0CollectorUniNode;
37
import ai.timefold.solver.core.impl.move.streams.dataset.DataStreamFactory;
48
import ai.timefold.solver.core.impl.move.streams.dataset.bi.JoinBiDataStream;
59
import ai.timefold.solver.core.impl.move.streams.dataset.common.AbstractDataStream;
@@ -12,10 +16,15 @@
1216
import ai.timefold.solver.core.impl.move.streams.maybeapi.UniDataMapper;
1317
import ai.timefold.solver.core.impl.move.streams.maybeapi.stream.BiDataStream;
1418
import ai.timefold.solver.core.impl.move.streams.maybeapi.stream.UniDataStream;
15-
19+
import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
1620
import org.jspecify.annotations.NullMarked;
1721
import org.jspecify.annotations.Nullable;
1822

23+
import java.util.Objects;
24+
import java.util.function.Function;
25+
26+
import static ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor.oneKeyGroupBy;
27+
1928
@NullMarked
2029
public abstract class AbstractUniDataStream<Solution_, A> extends AbstractDataStream<Solution_>
2130
implements UniDataStream<Solution_, A> {
@@ -89,6 +98,30 @@ private <B> UniDataStream<Solution_, A> ifExistsOrNot(boolean shouldExist, UniDa
8998
joinerComber.mergedJoiner(), joinerComber.mergedFiltering()), childStreamList::add);
9099
}
91100

101+
/**
102+
* Convert the {@link UniConstraintStream} to a different {@link UniConstraintStream},
103+
* containing the set of tuples resulting from applying the group key mapping function
104+
* on all tuples of the original stream.
105+
* Neither tuple of the new stream {@link Objects#equals(Object, Object)} any other.
106+
*
107+
* @param groupKeyMapping mapping function to convert each element in the stream to a different element
108+
* @param <GroupKey_> the type of a fact in the destination {@link UniConstraintStream}'s tuple;
109+
* must honor {@link Object#hashCode() the general contract of hashCode}.
110+
*/
111+
protected <GroupKey_> AbstractUniDataStream<Solution_, GroupKey_> groupBy(Function<A, GroupKey_> groupKeyMapping) {
112+
// We do not expose this on the API, as this operation is not yet needed in any of the moves.
113+
// The groupBy API will need revisiting if exposed as a feature of Move Streams, do not expose as is.
114+
GroupNodeConstructor<UniTuple<GroupKey_>> nodeConstructor =
115+
oneKeyGroupBy(groupKeyMapping, Group1Mapping0CollectorUniNode::new);
116+
return buildUniGroupBy(nodeConstructor);
117+
}
118+
119+
private <NewA> AbstractUniDataStream<Solution_, NewA> buildUniGroupBy(GroupNodeConstructor<UniTuple<NewA>> nodeConstructor) {
120+
var stream = shareAndAddChild(new UniGroupUniDataStream<>(dataStreamFactory, this, nodeConstructor));
121+
return dataStreamFactory.share(new AftBridgeUniDataStream<>(dataStreamFactory, stream),
122+
stream::setAftBridge);
123+
}
124+
92125
@Override
93126
public <ResultA_> UniDataStream<Solution_, ResultA_> map(UniDataMapper<Solution_, A, ResultA_> mapping) {
94127
var stream = shareAndAddChild(new UniMapUniDataStream<>(dataStreamFactory, this, mapping));
@@ -107,7 +140,7 @@ public AbstractUniDataStream<Solution_, A> distinct() {
107140
if (guaranteesDistinct()) {
108141
return this; // Already distinct, no need to create a new stream.
109142
}
110-
throw new UnsupportedOperationException();
143+
return groupBy(ConstantLambdaUtils.identity());
111144
}
112145

113146
public UniDataset<Solution_, A> createDataset() {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package ai.timefold.solver.core.impl.move.streams.dataset.uni;
2+
3+
import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor;
4+
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;
5+
import ai.timefold.solver.core.impl.move.streams.dataset.DataStreamFactory;
6+
import ai.timefold.solver.core.impl.move.streams.dataset.common.DataNodeBuildHelper;
7+
import ai.timefold.solver.core.impl.move.streams.dataset.common.bridge.AftBridgeUniDataStream;
8+
import org.jspecify.annotations.NullMarked;
9+
import org.jspecify.annotations.Nullable;
10+
11+
import java.util.Objects;
12+
13+
@NullMarked
14+
final class UniGroupUniDataStream<Solution_, A, NewA>
15+
extends AbstractUniDataStream<Solution_, A> {
16+
17+
private final GroupNodeConstructor<UniTuple<NewA>> nodeConstructor;
18+
private @Nullable AftBridgeUniDataStream<Solution_, NewA> aftStream;
19+
20+
public UniGroupUniDataStream(DataStreamFactory<Solution_> dataStreamFactory, AbstractUniDataStream<Solution_, A> parent,
21+
GroupNodeConstructor<UniTuple<NewA>> nodeConstructor) {
22+
super(dataStreamFactory, parent);
23+
this.nodeConstructor = nodeConstructor;
24+
}
25+
26+
public void setAftBridge(AftBridgeUniDataStream<Solution_, NewA> aftStream) {
27+
this.aftStream = aftStream;
28+
}
29+
30+
// ************************************************************************
31+
// Node creation
32+
// ************************************************************************
33+
34+
@Override
35+
public void buildNode(DataNodeBuildHelper<Solution_> buildHelper) {
36+
var aftStreamChildList = aftStream.getChildStreamList();
37+
nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, aftStreamChildList, this,
38+
dataStreamFactory.getEnvironmentMode());
39+
}
40+
41+
// ************************************************************************
42+
// Equality for node sharing
43+
// ************************************************************************
44+
45+
@Override
46+
public boolean equals(Object object) {
47+
if (this == object)
48+
return true;
49+
if (object == null || getClass() != object.getClass())
50+
return false;
51+
var that = (UniGroupUniDataStream<?, ?, ?>) object;
52+
return Objects.equals(parent, that.parent) && Objects.equals(nodeConstructor, that.nodeConstructor);
53+
}
54+
55+
@Override
56+
public int hashCode() {
57+
return Objects.hash(parent, nodeConstructor);
58+
}
59+
60+
@Override
61+
public String toString() {
62+
return "UniGroup()";
63+
}
64+
65+
}

core/src/test/java/ai/timefold/solver/core/impl/move/MoveStreamsBasedLocalSearchTest.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
package ai.timefold.solver.core.impl.move;
22

3-
import static org.assertj.core.api.Assertions.assertThatCode;
4-
import static org.mockito.Mockito.doReturn;
5-
import static org.mockito.Mockito.mock;
6-
7-
import java.util.HashSet;
8-
import java.util.Random;
9-
103
import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore;
114
import ai.timefold.solver.core.api.score.calculator.EasyScoreCalculator;
125
import ai.timefold.solver.core.config.localsearch.decider.acceptor.LocalSearchAcceptorConfig;
136
import ai.timefold.solver.core.config.localsearch.decider.forager.LocalSearchForagerConfig;
7+
import ai.timefold.solver.core.config.solver.EnvironmentMode;
148
import ai.timefold.solver.core.config.solver.termination.TerminationConfig;
159
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
1610
import ai.timefold.solver.core.impl.heuristic.HeuristicConfigPolicy;
@@ -30,10 +24,16 @@
3024
import ai.timefold.solver.core.testdomain.TestdataEntity;
3125
import ai.timefold.solver.core.testdomain.TestdataSolution;
3226
import ai.timefold.solver.core.testdomain.TestdataValue;
33-
3427
import org.jspecify.annotations.NonNull;
3528
import org.junit.jupiter.api.Test;
3629

30+
import java.util.HashSet;
31+
import java.util.Random;
32+
33+
import static org.assertj.core.api.Assertions.assertThatCode;
34+
import static org.mockito.Mockito.doReturn;
35+
import static org.mockito.Mockito.mock;
36+
3737
class MoveStreamsBasedLocalSearchTest {
3838

3939
@Test
@@ -95,7 +95,7 @@ void changeMoveBasedLocalSearch() {
9595
.genuineVariable()
9696
.ensurePlanningVariable();
9797
var moveProvider = new ChangeMoveProvider<>(variableMetaModel);
98-
var moveStreamFactory = new DefaultMoveStreamFactory<>(solutionDescriptor);
98+
var moveStreamFactory = new DefaultMoveStreamFactory<>(solutionDescriptor, EnvironmentMode.PHASE_ASSERT);
9999
var moveProducer = moveProvider.apply(moveStreamFactory);
100100
// Random selection otherwise LS gets stuck in an endless loop.
101101
return new MoveStreamsBasedMoveRepository<>(moveStreamFactory, moveProducer, true);

0 commit comments

Comments
 (0)