Skip to content

Commit 3e65779

Browse files
authored
Merge pull request #1095 from borglab/feature/linear_improvements
2 parents a5bee15 + 6337241 commit 3e65779

File tree

11 files changed

+189
-40
lines changed

11 files changed

+189
-40
lines changed

gtsam/base/FastSet.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
#pragma once
2020

21+
#if BOOST_VERSION >= 107400
22+
#include <boost/serialization/library_version_type.hpp>
23+
#endif
2124
#include <boost/serialization/nvp.hpp>
2225
#include <boost/serialization/set.hpp>
2326
#include <gtsam/base/FastDefaultAllocator.h>

gtsam/base/serialization.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <string>
2626

2727
// includes for standard serialization types
28+
#include <boost/serialization/version.hpp>
2829
#include <boost/serialization/optional.hpp>
2930
#include <boost/serialization/shared_ptr.hpp>
3031
#include <boost/serialization/vector.hpp>

gtsam/discrete/DiscreteConditional.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
225225

226226
/* ****************************************************************************/
227227
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
228-
size_t parent_value) const {
228+
size_t frontal) const {
229229
if (nrFrontals() != 1)
230230
throw std::invalid_argument(
231231
"Single value likelihood can only be invoked on single-variable "
232232
"conditional");
233233
DiscreteValues values;
234-
values.emplace(keys_[0], parent_value);
234+
values.emplace(keys_[0], frontal);
235235
return likelihood(values);
236236
}
237237

gtsam/discrete/DiscreteConditional.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ class GTSAM_EXPORT DiscreteConditional
177177
const DiscreteValues& frontalValues) const;
178178

179179
/** Single variable version of likelihood. */
180-
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
180+
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
181181

182182
/**
183183
* sample

gtsam/linear/GaussianConditional.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,50 @@ namespace gtsam {
224224
}
225225
}
226226

227+
/* ************************************************************************ */
228+
JacobianFactor::shared_ptr GaussianConditional::likelihood(
229+
const VectorValues& frontalValues) const {
230+
// Error is |Rx - (d - Sy - Tz - ...)|^2
231+
// so when we instantiate x (which has to be completely known) we beget:
232+
// |Sy + Tz + ... - (d - Rx)|^2
233+
// The noise model just transfers over!
234+
235+
// Get frontalValues as vector
236+
const Vector x =
237+
frontalValues.vector(KeyVector(beginFrontals(), endFrontals()));
238+
239+
// Copy the augmented Jacobian matrix:
240+
auto newAb = Ab_;
241+
242+
// Restrict view to parent blocks
243+
newAb.firstBlock() += nrFrontals_;
244+
245+
// Update right-hand-side (last column)
246+
auto last = newAb.matrix().cols() - 1;
247+
const auto RR = R().triangularView<Eigen::Upper>();
248+
newAb.matrix().col(last) -= RR * x;
249+
250+
// The keys now do not include the frontal keys:
251+
KeyVector newKeys;
252+
newKeys.reserve(nrParents());
253+
for (auto&& key : parents()) newKeys.push_back(key);
254+
255+
// Hopefully second newAb copy below is optimized out...
256+
return boost::make_shared<JacobianFactor>(newKeys, newAb, model_);
257+
}
258+
259+
/* **************************************************************************/
260+
JacobianFactor::shared_ptr GaussianConditional::likelihood(
261+
const Vector& frontal) const {
262+
if (nrFrontals() != 1)
263+
throw std::invalid_argument(
264+
"GaussianConditional Single value likelihood can only be invoked on "
265+
"single-variable conditional");
266+
VectorValues values;
267+
values.insert(keys_[0], frontal);
268+
return likelihood(values);
269+
}
270+
227271
/* ************************************************************************ */
228272
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
229273
std::mt19937_64* rng) const {

gtsam/linear/GaussianConditional.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ namespace gtsam {
151151
/** Performs transpose backsubstition in place on values */
152152
void solveTransposeInPlace(VectorValues& gy) const;
153153

154+
/** Convert to a likelihood factor by providing value before bar. */
155+
JacobianFactor::shared_ptr likelihood(
156+
const VectorValues& frontalValues) const;
157+
158+
/** Single variable version of likelihood. */
159+
JacobianFactor::shared_ptr likelihood(const Vector& frontal) const;
160+
154161
/**
155162
* Sample from conditional, zero parent version
156163
* Example:

gtsam/linear/VectorValues.cpp

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace gtsam {
3333
using boost::adaptors::map_values;
3434
using boost::accumulate;
3535

36-
/* ************************************************************************* */
36+
/* ************************************************************************ */
3737
VectorValues::VectorValues(const VectorValues& first, const VectorValues& second)
3838
{
3939
// Merge using predicate for comparing first of pair
@@ -44,7 +44,7 @@ namespace gtsam {
4444
throw invalid_argument("Requested to merge two VectorValues that have one or more variables in common.");
4545
}
4646

47-
/* ************************************************************************* */
47+
/* ************************************************************************ */
4848
VectorValues::VectorValues(const Vector& x, const Dims& dims) {
4949
using Pair = pair<const Key, size_t>;
5050
size_t j = 0;
@@ -61,7 +61,7 @@ namespace gtsam {
6161
}
6262
}
6363

64-
/* ************************************************************************* */
64+
/* ************************************************************************ */
6565
VectorValues::VectorValues(const Vector& x, const Scatter& scatter) {
6666
size_t j = 0;
6767
for (const SlotEntry& v : scatter) {
@@ -74,7 +74,7 @@ namespace gtsam {
7474
}
7575
}
7676

77-
/* ************************************************************************* */
77+
/* ************************************************************************ */
7878
VectorValues VectorValues::Zero(const VectorValues& other)
7979
{
8080
VectorValues result;
@@ -87,7 +87,7 @@ namespace gtsam {
8787
return result;
8888
}
8989

90-
/* ************************************************************************* */
90+
/* ************************************************************************ */
9191
VectorValues::iterator VectorValues::insert(const std::pair<Key, Vector>& key_value) {
9292
std::pair<iterator, bool> result = values_.insert(key_value);
9393
if(!result.second)
@@ -97,7 +97,7 @@ namespace gtsam {
9797
return result.first;
9898
}
9999

100-
/* ************************************************************************* */
100+
/* ************************************************************************ */
101101
void VectorValues::update(const VectorValues& values)
102102
{
103103
iterator hint = begin();
@@ -115,7 +115,7 @@ namespace gtsam {
115115
}
116116
}
117117

118-
/* ************************************************************************* */
118+
/* ************************************************************************ */
119119
void VectorValues::insert(const VectorValues& values)
120120
{
121121
size_t originalSize = size();
@@ -124,14 +124,14 @@ namespace gtsam {
124124
throw invalid_argument("Requested to insert a VectorValues into another VectorValues that already contains one or more of its keys.");
125125
}
126126

127-
/* ************************************************************************* */
127+
/* ************************************************************************ */
128128
void VectorValues::setZero()
129129
{
130130
for(Vector& v: values_ | map_values)
131131
v.setZero();
132132
}
133133

134-
/* ************************************************************************* */
134+
/* ************************************************************************ */
135135
GTSAM_EXPORT ostream& operator<<(ostream& os, const VectorValues& v) {
136136
// Change print depending on whether we are using TBB
137137
#ifdef GTSAM_USE_TBB
@@ -150,15 +150,15 @@ namespace gtsam {
150150
return os;
151151
}
152152

153-
/* ************************************************************************* */
153+
/* ************************************************************************ */
154154
void VectorValues::print(const string& str,
155155
const KeyFormatter& formatter) const {
156156
cout << str << ": " << size() << " elements\n";
157157
cout << key_formatter(formatter) << *this;
158158
cout.flush();
159159
}
160160

161-
/* ************************************************************************* */
161+
/* ************************************************************************ */
162162
bool VectorValues::equals(const VectorValues& x, double tol) const {
163163
if(this->size() != x.size())
164164
return false;
@@ -170,7 +170,7 @@ namespace gtsam {
170170
return true;
171171
}
172172

173-
/* ************************************************************************* */
173+
/* ************************************************************************ */
174174
Vector VectorValues::vector() const {
175175
// Count dimensions
176176
DenseIndex totalDim = 0;
@@ -187,7 +187,7 @@ namespace gtsam {
187187
return result;
188188
}
189189

190-
/* ************************************************************************* */
190+
/* ************************************************************************ */
191191
Vector VectorValues::vector(const Dims& keys) const
192192
{
193193
// Count dimensions
@@ -203,12 +203,12 @@ namespace gtsam {
203203
return result;
204204
}
205205

206-
/* ************************************************************************* */
206+
/* ************************************************************************ */
207207
void VectorValues::swap(VectorValues& other) {
208208
this->values_.swap(other.values_);
209209
}
210210

211-
/* ************************************************************************* */
211+
/* ************************************************************************ */
212212
namespace internal
213213
{
214214
bool structureCompareOp(const boost::tuple<VectorValues::value_type,
@@ -219,14 +219,14 @@ namespace gtsam {
219219
}
220220
}
221221

222-
/* ************************************************************************* */
222+
/* ************************************************************************ */
223223
bool VectorValues::hasSameStructure(const VectorValues other) const
224224
{
225225
return accumulate(combine(*this, other)
226226
| transformed(internal::structureCompareOp), true, logical_and<bool>());
227227
}
228228

229-
/* ************************************************************************* */
229+
/* ************************************************************************ */
230230
double VectorValues::dot(const VectorValues& v) const
231231
{
232232
if(this->size() != v.size())
@@ -244,12 +244,12 @@ namespace gtsam {
244244
return result;
245245
}
246246

247-
/* ************************************************************************* */
247+
/* ************************************************************************ */
248248
double VectorValues::norm() const {
249249
return std::sqrt(this->squaredNorm());
250250
}
251251

252-
/* ************************************************************************* */
252+
/* ************************************************************************ */
253253
double VectorValues::squaredNorm() const {
254254
double sumSquares = 0.0;
255255
using boost::adaptors::map_values;
@@ -258,7 +258,7 @@ namespace gtsam {
258258
return sumSquares;
259259
}
260260

261-
/* ************************************************************************* */
261+
/* ************************************************************************ */
262262
VectorValues VectorValues::operator+(const VectorValues& c) const
263263
{
264264
if(this->size() != c.size())
@@ -278,13 +278,13 @@ namespace gtsam {
278278
return result;
279279
}
280280

281-
/* ************************************************************************* */
281+
/* ************************************************************************ */
282282
VectorValues VectorValues::add(const VectorValues& c) const
283283
{
284284
return *this + c;
285285
}
286286

287-
/* ************************************************************************* */
287+
/* ************************************************************************ */
288288
VectorValues& VectorValues::operator+=(const VectorValues& c)
289289
{
290290
if(this->size() != c.size())
@@ -301,13 +301,13 @@ namespace gtsam {
301301
return *this;
302302
}
303303

304-
/* ************************************************************************* */
304+
/* ************************************************************************ */
305305
VectorValues& VectorValues::addInPlace(const VectorValues& c)
306306
{
307307
return *this += c;
308308
}
309309

310-
/* ************************************************************************* */
310+
/* ************************************************************************ */
311311
VectorValues& VectorValues::addInPlace_(const VectorValues& c)
312312
{
313313
for(const_iterator j2 = c.begin(); j2 != c.end(); ++j2) {
@@ -320,7 +320,7 @@ namespace gtsam {
320320
return *this;
321321
}
322322

323-
/* ************************************************************************* */
323+
/* ************************************************************************ */
324324
VectorValues VectorValues::operator-(const VectorValues& c) const
325325
{
326326
if(this->size() != c.size())
@@ -340,13 +340,13 @@ namespace gtsam {
340340
return result;
341341
}
342342

343-
/* ************************************************************************* */
343+
/* ************************************************************************ */
344344
VectorValues VectorValues::subtract(const VectorValues& c) const
345345
{
346346
return *this - c;
347347
}
348348

349-
/* ************************************************************************* */
349+
/* ************************************************************************ */
350350
VectorValues operator*(const double a, const VectorValues &v)
351351
{
352352
VectorValues result;
@@ -359,26 +359,57 @@ namespace gtsam {
359359
return result;
360360
}
361361

362-
/* ************************************************************************* */
362+
/* ************************************************************************ */
363363
VectorValues VectorValues::scale(const double a) const
364364
{
365365
return a * *this;
366366
}
367367

368-
/* ************************************************************************* */
368+
/* ************************************************************************ */
369369
VectorValues& VectorValues::operator*=(double alpha)
370370
{
371371
for(Vector& v: *this | map_values)
372372
v *= alpha;
373373
return *this;
374374
}
375375

376-
/* ************************************************************************* */
376+
/* ************************************************************************ */
377377
VectorValues& VectorValues::scaleInPlace(double alpha)
378378
{
379379
return *this *= alpha;
380380
}
381381

382-
/* ************************************************************************* */
382+
/* ************************************************************************ */
383+
string VectorValues::html(const KeyFormatter& keyFormatter) const {
384+
stringstream ss;
385+
386+
// Print out preamble.
387+
ss << "<div>\n<table class='VectorValues'>\n <thead>\n";
388+
389+
// Print out header row.
390+
ss << " <tr><th>Variable</th><th>value</th></tr>\n";
391+
392+
// Finish header and start body.
393+
ss << " </thead>\n <tbody>\n";
394+
395+
// Print out all rows.
396+
#ifdef GTSAM_USE_TBB
397+
// TBB uses un-ordered map, so inefficiently order them:
398+
std::map<Key, Vector> ordered;
399+
for (const auto& kv : *this) ordered.emplace(kv);
400+
for (const auto& kv : ordered) {
401+
#else
402+
for (const auto& kv : *this) {
403+
#endif
404+
ss << " <tr>";
405+
ss << "<th>" << keyFormatter(kv.first) << "</th><td>"
406+
<< kv.second.transpose() << "</td>";
407+
ss << "</tr>\n";
408+
}
409+
ss << " </tbody>\n</table>\n</div>";
410+
return ss.str();
411+
}
412+
413+
/* ************************************************************************ */
383414

384415
} // \namespace gtsam

0 commit comments

Comments
 (0)