Skip to content

Commit 76c7419

Browse files
authored
Merge pull request #1041 from borglab/release/4.2a3
2 parents d6f3468 + 91de3cb commit 76c7419

22 files changed

+527
-175
lines changed

.github/workflows/build-linux.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
BOOST_VERSION: 1.67.0
1616

1717
strategy:
18-
fail-fast: false
18+
fail-fast: true
1919
matrix:
2020
# Github Actions requires a single row to be added to the build matrix.
2121
# See https://help.github.com/en/articles/workflow-syntax-for-github-actions.

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ endif()
1111
set (GTSAM_VERSION_MAJOR 4)
1212
set (GTSAM_VERSION_MINOR 2)
1313
set (GTSAM_VERSION_PATCH 0)
14-
set (GTSAM_PRERELEASE_VERSION "a2")
14+
set (GTSAM_PRERELEASE_VERSION "a3")
1515
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
1616

1717
if (${GTSAM_VERSION_PATCH} EQUAL 0)

gtsam/discrete/DecisionTreeFactor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ namespace gtsam {
5757
/** Default constructor for I/O */
5858
DecisionTreeFactor();
5959

60-
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
60+
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
6161
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
6262

6363
/** Constructor from doubles */
@@ -139,22 +139,22 @@ namespace gtsam {
139139
/**
140140
* Apply binary operator (*this) "op" f
141141
* @param f the second argument for op
142-
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
142+
* @param op a binary operator that operates on AlgebraicDecisionTree
143143
*/
144144
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
145145

146146
/**
147147
* Combine frontal variables using binary operator "op"
148148
* @param nrFrontals nr. of frontal to combine variables in this factor
149-
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
149+
* @param op a binary operator that operates on AlgebraicDecisionTree
150150
* @return shared pointer to newly created DecisionTreeFactor
151151
*/
152152
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
153153

154154
/**
155155
* Combine frontal variables in an Ordering using binary operator "op"
156156
* @param nrFrontals nr. of frontal to combine variables in this factor
157-
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
157+
* @param op a binary operator that operates on AlgebraicDecisionTree
158158
* @return shared pointer to newly created DecisionTreeFactor
159159
*/
160160
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;

gtsam/discrete/DiscreteBayesNet.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#pragma once
2020

2121
#include <gtsam/discrete/DiscreteConditional.h>
22-
#include <gtsam/discrete/DiscretePrior.h>
22+
#include <gtsam/discrete/DiscreteDistribution.h>
2323
#include <gtsam/inference/BayesNet.h>
2424
#include <gtsam/inference/FactorGraph.h>
2525

@@ -79,9 +79,9 @@ namespace gtsam {
7979
// Add inherited versions of add.
8080
using Base::add;
8181

82-
/** Add a DiscretePrior using a table or a string */
82+
/** Add a DiscreteDistribution using a table or a string */
8383
void add(const DiscreteKey& key, const std::string& spec) {
84-
emplace_shared<DiscretePrior>(key, spec);
84+
emplace_shared<DiscreteDistribution>(key, spec);
8585
}
8686

8787
/** Add a DiscreteCondtional */

gtsam/discrete/DiscreteConditional.cpp

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <string>
3131
#include <vector>
3232
#include <utility>
33+
#include <set>
3334

3435
using namespace std;
3536
using std::stringstream;
@@ -38,38 +39,97 @@ using std::pair;
3839
namespace gtsam {
3940

4041
// Instantiate base class
41-
template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional> ;
42+
template class GTSAM_EXPORT
43+
Conditional<DecisionTreeFactor, DiscreteConditional>;
4244

43-
/* ******************************************************************************** */
45+
/* ************************************************************************** */
4446
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
45-
const DecisionTreeFactor& f) :
46-
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
47-
}
47+
const DecisionTreeFactor& f)
48+
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
4849

49-
/* ******************************************************************************** */
50+
/* ************************************************************************** */
51+
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
52+
const DiscreteKeys& keys,
53+
const ADT& potentials)
54+
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
55+
56+
/* ************************************************************************** */
5057
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
51-
const DecisionTreeFactor& marginal) :
52-
BaseFactor(
53-
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
54-
joint.size()-marginal.size()) {
55-
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
56-
cout << (firstFrontalKey()) << endl; //TODO Print all keys
57-
}
58+
const DecisionTreeFactor& marginal)
59+
: BaseFactor(joint / marginal),
60+
BaseConditional(joint.size() - marginal.size()) {}
5861

59-
/* ******************************************************************************** */
62+
/* ************************************************************************** */
6063
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
61-
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
62-
DiscreteConditional(joint, marginal) {
64+
const DecisionTreeFactor& marginal,
65+
const Ordering& orderedKeys)
66+
: DiscreteConditional(joint, marginal) {
6367
keys_.clear();
6468
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
6569
}
6670

67-
/* ******************************************************************************** */
71+
/* ************************************************************************** */
6872
DiscreteConditional::DiscreteConditional(const Signature& signature)
6973
: BaseFactor(signature.discreteKeys(), signature.cpt()),
7074
BaseConditional(1) {}
7175

72-
/* ******************************************************************************** */
76+
/* ************************************************************************** */
77+
DiscreteConditional DiscreteConditional::operator*(
78+
const DiscreteConditional& other) const {
79+
// Take union of frontal keys
80+
std::set<Key> newFrontals;
81+
for (auto&& key : this->frontals()) newFrontals.insert(key);
82+
for (auto&& key : other.frontals()) newFrontals.insert(key);
83+
84+
// Check if frontals overlapped
85+
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
86+
throw std::invalid_argument(
87+
"DiscreteConditional::operator* called with overlapping frontal keys.");
88+
89+
// Now, add cardinalities.
90+
DiscreteKeys discreteKeys;
91+
for (auto&& key : frontals())
92+
discreteKeys.emplace_back(key, cardinality(key));
93+
for (auto&& key : other.frontals())
94+
discreteKeys.emplace_back(key, other.cardinality(key));
95+
96+
// Sort
97+
std::sort(discreteKeys.begin(), discreteKeys.end());
98+
99+
// Add parents to set, to make them unique
100+
std::set<DiscreteKey> parents;
101+
for (auto&& key : this->parents())
102+
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
103+
for (auto&& key : other.parents())
104+
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
105+
106+
// Finally, add parents to keys, in order
107+
for (auto&& dk : parents) discreteKeys.push_back(dk);
108+
109+
ADT product = ADT::apply(other, ADT::Ring::mul);
110+
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
111+
}
112+
113+
/* ************************************************************************** */
114+
DiscreteConditional DiscreteConditional::marginal(Key key) const {
115+
if (nrParents() > 0)
116+
throw std::invalid_argument(
117+
"DiscreteConditional::marginal: single argument version only valid for "
118+
"fully specified joint distributions (i.e., no parents).");
119+
120+
// Calculate the keys as the frontal keys without the given key.
121+
DiscreteKeys discreteKeys{{key, cardinality(key)}};
122+
123+
// Calculate sum
124+
ADT adt(*this);
125+
for (auto&& k : frontals())
126+
if (k != key) adt = adt.sum(k, cardinality(k));
127+
128+
// Return new factor
129+
return DiscreteConditional(1, discreteKeys, adt);
130+
}
131+
132+
/* ************************************************************************** */
73133
void DiscreteConditional::print(const string& s,
74134
const KeyFormatter& formatter) const {
75135
cout << s << " P( ";
@@ -82,7 +142,7 @@ void DiscreteConditional::print(const string& s,
82142
cout << formatter(*it) << " ";
83143
}
84144
}
85-
cout << ")";
145+
cout << "):\n";
86146
ADT::print("");
87147
cout << endl;
88148
}

gtsam/discrete/DiscreteConditional.h

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional
4949
/// @name Standard Constructors
5050
/// @{
5151

52-
/** default constructor needed for serialization */
52+
/// Default constructor needed for serialization.
5353
DiscreteConditional() {}
5454

55-
/** constructor from factor */
55+
/// Construct from factor, taking the first `nFrontals` keys as frontals.
5656
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
5757

58+
/**
59+
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
60+
* `nFrontals` keys as frontals, in the order given.
61+
*/
62+
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
63+
const ADT& potentials);
64+
5865
/** Construct from signature */
59-
DiscreteConditional(const Signature& signature);
66+
explicit DiscreteConditional(const Signature& signature);
6067

6168
/**
6269
* Construct from key, parents, and a Signature::Table specifying the
@@ -82,31 +89,45 @@ class GTSAM_EXPORT DiscreteConditional
8289
const std::string& spec)
8390
: DiscreteConditional(Signature(key, parents, spec)) {}
8491

85-
/// No-parent specialization; can also use DiscretePrior.
92+
/// No-parent specialization; can also use DiscreteDistribution.
8693
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
8794
: DiscreteConditional(Signature(key, {}, spec)) {}
8895

89-
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
96+
/**
97+
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
98+
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
99+
*/
90100
DiscreteConditional(const DecisionTreeFactor& joint,
91101
const DecisionTreeFactor& marginal);
92102

93-
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
103+
/**
104+
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
105+
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
106+
* Makes sure the keys are ordered as given. Does not check orderedKeys.
107+
*/
94108
DiscreteConditional(const DecisionTreeFactor& joint,
95109
const DecisionTreeFactor& marginal,
96110
const Ordering& orderedKeys);
97111

98112
/**
99-
* Combine several conditional into a single one.
100-
* The conditionals must be given in increasing order, meaning that the
101-
* parents of any conditional may not include a conditional coming before it.
102-
* @param firstConditional Iterator to the first conditional to combine, must
103-
* dereference to a shared_ptr<DiscreteConditional>.
104-
* @param lastConditional Iterator to after the last conditional to combine,
105-
* must dereference to a shared_ptr<DiscreteConditional>.
106-
* */
107-
template <typename ITERATOR>
108-
static shared_ptr Combine(ITERATOR firstConditional,
109-
ITERATOR lastConditional);
113+
* @brief Combine two conditionals, yielding a new conditional with the union
114+
* of the frontal keys, ordered by gtsam::Key.
115+
*
116+
* The two conditionals must make a valid Bayes net fragment, i.e.,
117+
* the frontal variables cannot overlap, and must be acyclic:
118+
* Example of correct use:
119+
* P(A,B) = P(A|B) * P(B)
120+
* P(A,B|C) = P(A|B) * P(B|C)
121+
* P(A,B,C) = P(A,B|C) * P(C)
122+
* Example of incorrect use:
123+
* P(A|B) * P(A|C) = ?
124+
* P(A|B) * P(B|A) = ?
125+
* We check for overlapping frontals, but do *not* check for cyclic.
126+
*/
127+
DiscreteConditional operator*(const DiscreteConditional& other) const;
128+
129+
/** Calculate marginal on given key, no parent case. */
130+
DiscreteConditional marginal(Key key) const;
110131

111132
/// @}
112133
/// @name Testable
@@ -136,11 +157,6 @@ class GTSAM_EXPORT DiscreteConditional
136157
return ADT::operator()(values);
137158
}
138159

139-
/** Convert to a factor */
140-
DecisionTreeFactor::shared_ptr toFactor() const {
141-
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
142-
}
143-
144160
/** Restrict to given parent values, returns DecisionTreeFactor */
145161
DecisionTreeFactor::shared_ptr choose(
146162
const DiscreteValues& parentsValues) const;
@@ -208,23 +224,4 @@ class GTSAM_EXPORT DiscreteConditional
208224
template <>
209225
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
210226

211-
/* ************************************************************************* */
212-
template <typename ITERATOR>
213-
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
214-
ITERATOR firstConditional, ITERATOR lastConditional) {
215-
// TODO: check for being a clique
216-
217-
// multiply all the potentials of the given conditionals
218-
size_t nrFrontals = 0;
219-
DecisionTreeFactor product;
220-
for (ITERATOR it = firstConditional; it != lastConditional;
221-
++it, ++nrFrontals) {
222-
DiscreteConditional::shared_ptr c = *it;
223-
DecisionTreeFactor::shared_ptr factor = c->toFactor();
224-
product = (*factor) * product;
225-
}
226-
// and then create a new multi-frontal conditional
227-
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
228-
}
229-
230227
} // namespace gtsam
Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,23 @@
1010
* -------------------------------------------------------------------------- */
1111

1212
/**
13-
* @file DiscretePrior.cpp
13+
* @file DiscreteDistribution.cpp
1414
* @date December 2021
1515
* @author Frank Dellaert
1616
*/
1717

18-
#include <gtsam/discrete/DiscretePrior.h>
18+
#include <gtsam/discrete/DiscreteDistribution.h>
19+
20+
#include <vector>
1921

2022
namespace gtsam {
2123

22-
void DiscretePrior::print(const std::string& s,
23-
const KeyFormatter& formatter) const {
24+
void DiscreteDistribution::print(const std::string& s,
25+
const KeyFormatter& formatter) const {
2426
Base::print(s, formatter);
2527
}
2628

27-
double DiscretePrior::operator()(size_t value) const {
29+
double DiscreteDistribution::operator()(size_t value) const {
2830
if (nrFrontals() != 1)
2931
throw std::invalid_argument(
3032
"Single value operator can only be invoked on single-variable "
@@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const {
3436
return Base::operator()(values);
3537
}
3638

37-
std::vector<double> DiscretePrior::pmf() const {
39+
std::vector<double> DiscreteDistribution::pmf() const {
3840
if (nrFrontals() != 1)
3941
throw std::invalid_argument(
40-
"DiscretePrior::pmf only defined for single-variable priors");
42+
"DiscreteDistribution::pmf only defined for single-variable priors");
4143
const size_t nrValues = cardinalities_.at(keys_[0]);
4244
std::vector<double> array;
4345
array.reserve(nrValues);

0 commit comments

Comments
 (0)