@@ -140,11 +140,46 @@ TEST(DecisionTreeFactor, enumerate) {
140140 EXPECT (actual == expected);
141141}
142142
143+ namespace pruning_fixture {
144+
145+ DiscreteKey A (1 , 2 ), B(2 , 2 ), C(3 , 2 );
146+ DecisionTreeFactor f (A& B& C, " 1 5 3 7 2 6 4 8" );
147+
148+ DiscreteKey D (4 , 2 );
149+ DecisionTreeFactor factor (
150+ D& C & B & A,
151+ " 0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
152+ " 0.0 0.0 0.99995287 1.0 1.0 1.0 1.0" );
153+
154+ } // namespace pruning_fixture
155+
156+ /* ************************************************************************* */
157+ // Check if computing the correct threshold works.
158+ TEST (DecisionTreeFactor, ComputeThreshold) {
159+ using namespace pruning_fixture ;
160+
161+ // Only keep the leaves with the top 5 values.
162+ double threshold = f.computeThreshold (5 );
163+ EXPECT_DOUBLES_EQUAL (4.0 , threshold, 1e-9 );
164+
165+ // Check for more extreme pruning where we only keep the top 2 leaves
166+ threshold = f.computeThreshold (2 );
167+ EXPECT_DOUBLES_EQUAL (7.0 , threshold, 1e-9 );
168+
169+ threshold = factor.computeThreshold (5 );
170+ EXPECT_DOUBLES_EQUAL (0.99995287 , threshold, 1e-9 );
171+
172+ threshold = factor.computeThreshold (3 );
173+ EXPECT_DOUBLES_EQUAL (1.0 , threshold, 1e-9 );
174+
175+ threshold = factor.computeThreshold (6 );
176+ EXPECT_DOUBLES_EQUAL (0.61247742 , threshold, 1e-9 );
177+ }
178+
143179/* ************************************************************************* */
144180// Check pruning of the decision tree works as expected.
145181TEST (DecisionTreeFactor, Prune) {
146- DiscreteKey A (1 , 2 ), B (2 , 2 ), C (3 , 2 );
147- DecisionTreeFactor f (A & B & C, " 1 5 3 7 2 6 4 8" );
182+ using namespace pruning_fixture ;
148183
149184 // Only keep the leaves with the top 5 values.
150185 size_t maxNrAssignments = 5 ;
@@ -160,12 +195,6 @@ TEST(DecisionTreeFactor, Prune) {
160195 DecisionTreeFactor expected2 (A & B & C, " 0 0 0 7 0 0 0 8" );
161196 EXPECT (assert_equal (expected2, pruned2));
162197
163- DiscreteKey D (4 , 2 );
164- DecisionTreeFactor factor (
165- D & C & B & A,
166- " 0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
167- " 0.0 0.0 0.99995287 1.0 1.0 1.0 1.0" );
168-
169198 DecisionTreeFactor expected3 (D & C & B & A,
170199 " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
171200 " 0.999952870000 1.0 1.0 1.0 1.0" );
0 commit comments