3030#include < string>
3131#include < vector>
3232#include < utility>
33+ #include < set>
3334
3435using namespace std ;
3536using std::stringstream;
@@ -38,38 +39,97 @@ using std::pair;
3839namespace gtsam {
3940
4041// Instantiate base class
41- template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional> ;
42+ template class GTSAM_EXPORT
43+ Conditional<DecisionTreeFactor, DiscreteConditional>;
4244
43- /* ******************************************************************************** */
45+ /* ************************************************************************** */
4446DiscreteConditional::DiscreteConditional (const size_t nrFrontals,
45- const DecisionTreeFactor& f) :
46- BaseFactor (f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
47- }
47+ const DecisionTreeFactor& f)
48+ : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
4849
49- /* ******************************************************************************** */
50+ /* ************************************************************************** */
51+ DiscreteConditional::DiscreteConditional (size_t nrFrontals,
52+ const DiscreteKeys& keys,
53+ const ADT& potentials)
54+ : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
55+
56+ /* ************************************************************************** */
5057DiscreteConditional::DiscreteConditional (const DecisionTreeFactor& joint,
51- const DecisionTreeFactor& marginal) :
52- BaseFactor (
53- ISDEBUG (" DiscreteConditional::COUNT" ) ? joint : joint / marginal), BaseConditional(
54- joint.size()-marginal.size()) {
55- if (ISDEBUG (" DiscreteConditional::DiscreteConditional" ))
56- cout << (firstFrontalKey ()) << endl; // TODO Print all keys
57- }
58+ const DecisionTreeFactor& marginal)
59+ : BaseFactor(joint / marginal),
60+ BaseConditional (joint.size() - marginal.size()) {}
5861
59- /* ******************************************************************************** */
62+ /* ************************************************************************** */
6063DiscreteConditional::DiscreteConditional (const DecisionTreeFactor& joint,
61- const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
62- DiscreteConditional(joint, marginal) {
64+ const DecisionTreeFactor& marginal,
65+ const Ordering& orderedKeys)
66+ : DiscreteConditional(joint, marginal) {
6367 keys_.clear ();
6468 keys_.insert (keys_.end (), orderedKeys.begin (), orderedKeys.end ());
6569}
6670
67- /* ******************************************************************************** */
71+ /* ************************************************************************** */
6872DiscreteConditional::DiscreteConditional (const Signature& signature)
6973 : BaseFactor(signature.discreteKeys(), signature.cpt()),
7074 BaseConditional(1 ) {}
7175
72- /* ******************************************************************************** */
76+ /* ************************************************************************** */
77+ DiscreteConditional DiscreteConditional::operator *(
78+ const DiscreteConditional& other) const {
79+ // Take union of frontal keys
80+ std::set<Key> newFrontals;
81+ for (auto && key : this ->frontals ()) newFrontals.insert (key);
82+ for (auto && key : other.frontals ()) newFrontals.insert (key);
83+
84+ // Check if frontals overlapped
85+ if (nrFrontals () + other.nrFrontals () > newFrontals.size ())
86+ throw std::invalid_argument (
87+ " DiscreteConditional::operator* called with overlapping frontal keys." );
88+
89+ // Now, add cardinalities.
90+ DiscreteKeys discreteKeys;
91+ for (auto && key : frontals ())
92+ discreteKeys.emplace_back (key, cardinality (key));
93+ for (auto && key : other.frontals ())
94+ discreteKeys.emplace_back (key, other.cardinality (key));
95+
96+ // Sort
97+ std::sort (discreteKeys.begin (), discreteKeys.end ());
98+
99+ // Add parents to set, to make them unique
100+ std::set<DiscreteKey> parents;
101+ for (auto && key : this ->parents ())
102+ if (!newFrontals.count (key)) parents.emplace (key, cardinality (key));
103+ for (auto && key : other.parents ())
104+ if (!newFrontals.count (key)) parents.emplace (key, other.cardinality (key));
105+
106+ // Finally, add parents to keys, in order
107+ for (auto && dk : parents) discreteKeys.push_back (dk);
108+
109+ ADT product = ADT::apply (other, ADT::Ring::mul);
110+ return DiscreteConditional (newFrontals.size (), discreteKeys, product);
111+ }
112+
113+ /* ************************************************************************** */
114+ DiscreteConditional DiscreteConditional::marginal (Key key) const {
115+ if (nrParents () > 0 )
116+ throw std::invalid_argument (
117+ " DiscreteConditional::marginal: single argument version only valid for "
118+ " fully specified joint distributions (i.e., no parents)." );
119+
120+ // Calculate the keys as the frontal keys without the given key.
121+ DiscreteKeys discreteKeys{{key, cardinality (key)}};
122+
123+ // Calculate sum
124+ ADT adt (*this );
125+ for (auto && k : frontals ())
126+ if (k != key) adt = adt.sum (k, cardinality (k));
127+
128+ // Return new factor
129+ return DiscreteConditional (1 , discreteKeys, adt);
130+ }
131+
132+ /* ************************************************************************** */
73133void DiscreteConditional::print (const string& s,
74134 const KeyFormatter& formatter) const {
75135 cout << s << " P( " ;
@@ -82,7 +142,7 @@ void DiscreteConditional::print(const string& s,
82142 cout << formatter (*it) << " " ;
83143 }
84144 }
85- cout << " )" ;
145+ cout << " ): \n " ;
86146 ADT::print (" " );
87147 cout << endl;
88148}
0 commit comments