Skip to content

Commit 4f66a49

Browse files
authored
Merge pull request #1619 from borglab/release/4.2
2 parents a82f191 + 1a86944 commit 4f66a49

32 files changed

+918
-222
lines changed

gtsam/discrete/AlgebraicDecisionTree.h

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
namespace gtsam {
2929

3030
/**
31-
* Algebraic Decision Trees fix the range to double
32-
* Just has some nice constructors and some syntactic sugar
33-
* TODO: consider eliminating this class altogether?
31+
* An algebraic decision tree fixes the range of a DecisionTree to double.
32+
* Just has some nice constructors and some syntactic sugar.
33+
* TODO(dellaert): consider eliminating this class altogether?
3434
*
3535
* @ingroup discrete
3636
*/
@@ -80,20 +80,62 @@ namespace gtsam {
8080
AlgebraicDecisionTree(const L& label, double y1, double y2)
8181
: Base(label, y1, y2) {}
8282

83-
/** Create a new leaf function splitting on a variable */
83+
/**
84+
* @brief Create a new leaf function splitting on a variable
85+
*
86+
* @param labelC: The label with cardinality 2
87+
* @param y1: The value for the first key
88+
* @param y2: The value for the second key
89+
*
90+
* Example:
91+
* @code{.cpp}
92+
* std::pair<string, size_t> A {"a", 2};
93+
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
94+
* @endcode
95+
*/
8496
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
8597
double y2)
8698
: Base(labelC, y1, y2) {}
8799

88-
/** Create from keys and vector table */
100+
/**
101+
* @brief Create from keys with cardinalities and a vector table
102+
*
103+
* @param labelCs: The keys, with cardinalities, given as pairs
104+
* @param ys: The vector table
105+
*
106+
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
107+
* respectively, and a vector table of size 12:
108+
* @code{.cpp}
109+
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
110+
* const vector<double> cpt{
111+
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
112+
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
113+
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
114+
* @endcode
115+
* The table is given in the following order:
116+
* A=0, B=0, C=0
117+
* A=0, B=0, C=1
118+
* ...
119+
* A=1, B=1, C=1
120+
* Hence, the first line in the table is for A==0, and the second for A==1.
121+
* In each line, the first two entries are for B==0, the next two for B==1,
122+
* and the last two for B==2. Each pair is for a C value of 0 and 1.
123+
*/
89124
AlgebraicDecisionTree //
90125
(const std::vector<typename Base::LabelC>& labelCs,
91-
const std::vector<double>& ys) {
126+
const std::vector<double>& ys) {
92127
this->root_ =
93128
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
94129
}
95130

96-
/** Create from keys and string table */
131+
/**
132+
* @brief Create from keys and string table
133+
*
134+
* @param labelCs: The keys, with cardinalities, given as pairs
135+
* @param table: The string table, given as a string of doubles.
136+
*
137+
* @note Table needs to be in same order as the vector table in the other constructor.
138+
*/
97139
AlgebraicDecisionTree //
98140
(const std::vector<typename Base::LabelC>& labelCs,
99141
const std::string& table) {
@@ -108,7 +150,13 @@ namespace gtsam {
108150
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
109151
}
110152

111-
/** Create a new function splitting on a variable */
153+
/**
154+
* @brief Create a range of decision trees, splitting on a single variable.
155+
*
156+
* @param begin: Iterator to beginning of a range of decision trees
157+
* @param end: Iterator to end of a range of decision trees
158+
* @param label: The label to split on
159+
*/
112160
template <typename Iterator>
113161
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
114162
: Base(nullptr) {

gtsam/discrete/DecisionTree-inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ namespace gtsam {
622622
// B=1
623623
// A=0: 3
624624
// A=1: 4
625-
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
625+
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
626626
// exactly the same tree as above: the highest label is always the root.
627627
// However, it will be *way* faster if labels are given highest to lowest.
628628
template<typename L, typename Y>

gtsam/discrete/DecisionTree.h

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,23 @@
3737
namespace gtsam {
3838

3939
/**
40-
* Decision Tree
41-
* L = label for variables
42-
* Y = function range (any algebra), e.g., bool, int, double
40+
* @brief a decision tree is a function from assignments to values.
41+
* @tparam L label for variables
42+
* @tparam Y function range (any algebra), e.g., bool, int, double
43+
*
44+
* After creating a decision tree on some variables, the tree can be evaluated
45+
* on an assignment to those variables. Example:
46+
*
47+
* @code{.cpp}
48+
* // Create a decision stump one one variable 'a' with values 10 and 20.
49+
* DecisionTree<char, int> tree('a', 10, 20);
50+
*
51+
* // Evaluate the tree on an assignment to the variable.
52+
* int value0 = tree({{'a', 0}}); // value0 = 10
53+
* int value1 = tree({{'a', 1}}); // value1 = 20
54+
* @endcode
55+
*
56+
* More examples can be found in testDecisionTree.cpp
4357
*
4458
* @ingroup discrete
4559
*/
@@ -132,7 +146,8 @@ namespace gtsam {
132146
NodePtr root_;
133147

134148
protected:
135-
/** Internal recursive function to create from keys, cardinalities,
149+
/**
150+
* Internal recursive function to create from keys, cardinalities,
136151
* and Y values
137152
*/
138153
template<typename It, typename ValueIt>
@@ -163,7 +178,13 @@ namespace gtsam {
163178
/** Create a constant */
164179
explicit DecisionTree(const Y& y);
165180

166-
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
181+
/**
182+
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
183+
*
184+
* @param label The variable to split on.
185+
* @param y1 The value for the first assignment.
186+
* @param y2 The value for the second assignment.
187+
*/
167188
DecisionTree(const L& label, const Y& y1, const Y& y2);
168189

169190
/** Allow Label+Cardinality for convenience */

gtsam/discrete/DecisionTreeFactor.h

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,46 @@ namespace gtsam {
6363
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
6464
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
6565

66-
/** Constructor from doubles */
66+
/**
67+
* @brief Constructor from doubles
68+
*
69+
* @param keys The discrete keys.
70+
* @param table The table of values.
71+
*
72+
* @throw std::invalid_argument if the size of `table` does not match the
73+
* number of assignments.
74+
*
75+
* Example:
76+
* @code{.cpp}
77+
* DiscreteKey X(0,2), Y(1,3);
78+
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
79+
* DecisionTreeFactor f1({X, Y}, table);
80+
* @endcode
81+
*
82+
* The values in the table should be laid out so that the first key varies
83+
* the slowest, and the last key the fastest.
84+
*/
6785
DecisionTreeFactor(const DiscreteKeys& keys,
68-
const std::vector<double>& table);
86+
const std::vector<double>& table);
6987

70-
/** Constructor from string */
88+
/**
89+
* @brief Constructor from string
90+
*
91+
* @param keys The discrete keys.
92+
* @param table The table of values.
93+
*
94+
* @throw std::invalid_argument if the size of `table` does not match the
95+
* number of assignments.
96+
*
97+
* Example:
98+
* @code{.cpp}
99+
* DiscreteKey X(0,2), Y(1,3);
100+
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
101+
* @endcode
102+
*
103+
* The values in the table should be laid out so that the first key varies
104+
* the slowest, and the last key the fastest.
105+
*/
71106
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
72107

73108
/// Single-key specialization

gtsam/discrete/DiscreteBayesTree.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
5959

6060
//** evaluate conditional probability of subtree for given DiscreteValues */
6161
double evaluate(const DiscreteValues& values) const;
62+
63+
//** (Preferred) sugar for the above for given DiscreteValues */
64+
double operator()(const DiscreteValues& values) const {
65+
return evaluate(values);
66+
}
6267
};
6368

6469
/* ************************************************************************* */

gtsam/discrete/DiscreteFactorGraph.h

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,30 @@ class DiscreteJunctionTree;
4242

4343
/**
4444
* @brief Main elimination function for DiscreteFactorGraph.
45-
*
46-
* @param factors
47-
* @param keys
48-
* @return GTSAM_EXPORT
45+
*
46+
* @param factors The factor graph to eliminate.
47+
* @param frontalKeys An ordering for which variables to eliminate.
48+
* @return A pair of the resulting conditional and the separator factor.
4949
* @ingroup discrete
5050
*/
51-
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
52-
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
51+
GTSAM_EXPORT
52+
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
53+
EliminateDiscrete(const DiscreteFactorGraph& factors,
54+
const Ordering& frontalKeys);
55+
56+
/**
57+
* @brief Alternate elimination function for that creates non-normalized lookup tables.
58+
*
59+
* @param factors The factor graph to eliminate.
60+
* @param frontalKeys An ordering for which variables to eliminate.
61+
* @return A pair of the resulting lookup table and the separator factor.
62+
* @ingroup discrete
63+
*/
64+
GTSAM_EXPORT
65+
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
66+
EliminateForMPE(const DiscreteFactorGraph& factors,
67+
const Ordering& frontalKeys);
5368

54-
/* ************************************************************************* */
5569
template<> struct EliminationTraits<DiscreteFactorGraph>
5670
{
5771
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
@@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
6175
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
6276
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
6377
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
78+
6479
/// The default dense elimination function
6580
static std::pair<boost::shared_ptr<ConditionalType>,
6681
boost::shared_ptr<FactorType> >
6782
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
6883
return EliminateDiscrete(factors, keys);
6984
}
85+
7086
/// The default ordering generation function
7187
static Ordering DefaultOrderingFunc(
7288
const FactorGraphType& graph,
@@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
7591
}
7692
};
7793

78-
/* ************************************************************************* */
7994
/**
8095
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
8196
* Factor == DiscreteFactor
@@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
109124

110125
/** Implicit copy/downcast constructor to override explicit template container
111126
* constructor */
112-
template <class DERIVEDFACTOR>
113-
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
127+
template <class DERIVED_FACTOR>
128+
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
114129

115130
/// Destructor
116131
virtual ~DiscreteFactorGraph() {}
@@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
231246
/// @}
232247
}; // \ DiscreteFactorGraph
233248

234-
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
235-
EliminateForMPE(const DiscreteFactorGraph& factors,
236-
const Ordering& frontalKeys);
237-
238249
/// traits
239250
template <>
240251
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

gtsam/discrete/DiscreteJunctionTree.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,6 @@ namespace gtsam {
6666
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
6767
};
6868

69+
/// typedef for wrapper:
70+
using DiscreteCluster = DiscreteJunctionTree::Cluster;
6971
}

gtsam/discrete/DiscreteValues.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
120120
/// @}
121121
};
122122

123+
/// Free version of CartesianProduct.
124+
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
125+
return DiscreteValues::CartesianProduct(keys);
126+
}
127+
123128
/// Free version of markdown.
124129
std::string markdown(const DiscreteValues& values,
125130
const KeyFormatter& keyFormatter = DefaultKeyFormatter,

0 commit comments

Comments
 (0)