-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathauto_diff.h
More file actions
140 lines (125 loc) · 4.99 KB
/
auto_diff.h
File metadata and controls
140 lines (125 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright 2026 Julien Michot.
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <tinyopt/cost.h>
#include <tinyopt/log.h>
#include <tinyopt/math.h>
#include <tinyopt/diff/jet.h>
namespace tinyopt::diff {
/// Return the function `f` residuals with the jacobian d f(x)/d(x) around `x` using
/// automatic differentiation
template <typename X_t, typename CostOrResFunc>
auto Eval(const X_t &x, const CostOrResFunc &cost_or_res_func, bool check_residuals = true) {
using ptrait = traits::params_trait<X_t>;
using Scalar = typename ptrait::Scalar;
constexpr Index Dims = ptrait::Dims;
constexpr bool is_userdef_type =
!std::is_floating_point_v<X_t> && !traits::is_matrix_or_array_v<X_t>;
const Index dims = traits::DynDims(x);
// Construct the Jet
using Jet = diff::Jet<Scalar, Dims>;
// Construct the Jet for scalar or Matrix, so {Jet, Vector<Jet, N> ot nullptr_t}
auto x_jet = ptrait::template cast<Jet>(x);
if constexpr (is_userdef_type) { // X is user defined object
Vector<Jet, Dims> dx_jet = Vector<Jet, Dims>::Zero(dims);
for (Index i = 0; i < dims; ++i) {
// If X size at compile time is not known, we need to set the Jet.v
if constexpr (Dims == Dynamic) dx_jet[i].v = Vector<Scalar, Dynamic>::Zero(dims);
dx_jet[i].v[i] = 1;
}
using ptrait_jet = traits::params_trait<std::decay_t<decltype(x_jet)>>;
ptrait_jet::PlusEq(x_jet, dx_jet);
} else if constexpr (std::is_floating_point_v<X_t>) { // X is scalar
x_jet.v[0] = 1;
} else { // X is a Vector or Matrix
// Set Jet's v
for (int c = 0; c < x.cols(); ++c) {
for (int r = 0; r < x.rows(); ++r) {
const auto i = r + c * x.rows();
if constexpr (Dims == Dynamic) x_jet(r, c).v = Vector<Scalar, Dims>::Zero(dims);
x_jet(r, c).v[i] = 1;
}
}
}
// Support different function signatures
auto fg = [&](const auto &x) {
std::nullptr_t nul;
if constexpr (std::is_invocable_v<CostOrResFunc, const X_t &>)
return cost_or_res_func(x);
else if constexpr (std::is_invocable_v<CostOrResFunc, const X_t &, std::nullptr_t &>)
return cost_or_res_func(x, nul);
else if constexpr (std::is_invocable_v<CostOrResFunc, const X_t &, std::nullptr_t &,
std::nullptr_t &>)
return cost_or_res_func(x, nul, nul);
else { // likely a SparseMatrix<Scalar> hessian
SparseMatrix<Scalar> H;
H.resize(dims, dims);
return cost_or_res_func(x, nul, H);
}
};
// Retrieve the residuals
const auto res = fg(x_jet);
using ResType = typename std::decay_t<decltype(res)>;
// Make sure the return type is either a Jet or Matrix/Array<Jet>
static_assert(
traits::is_jet_type_v<ResType> ||
(traits::is_matrix_or_array_v<ResType> && traits::is_jet_type_v<typename ResType::Scalar>));
if constexpr (!traits::is_matrix_or_array_v<ResType>) { // One residual
return std::make_pair(res.a, res.v.transpose().eval());
} else {
constexpr int ResDims = traits::params_trait<ResType>::Dims;
const Index res_dims = traits::DynDims(res);
Matrix<Scalar, ResDims, Dims> J(res_dims, dims);
Vector<Scalar, ResDims> res_f(res.size());
if constexpr (traits::is_matrix_or_array_v<ResType>) {
if constexpr (ResType::ColsAtCompileTime != 1) { // Matrix or Vector with dynamic size
for (int c = 0; c < res.cols(); ++c)
for (int r = 0; r < res.rows(); ++r) {
const Index i = r + c * res.rows();
res_f[i] = res(r, c).a;
if constexpr (Dims == Dynamic) {
if (check_residuals && res(r, c).v.size() != dims) {
TINYOPT_LOG("⚠️ Residual ({},{}) is not connected to the parameters", r, c);
J.row(i).setZero();
continue;
}
}
J.row(i) = res(r, c).v;
}
} else { // Vector
for (Index i = 0; i < res_dims; ++i) {
res_f[i] = res[i].a;
if constexpr (Dims == Dynamic) {
if (check_residuals && res[i].v.size() != dims) {
TINYOPT_LOG("⚠️ Residual #{} is not connected to the parameters", i);
J.row(i).setZero();
continue;
}
}
J.row(i) = res[i].v;
}
}
} else { // scalar
for (Index i = 0; i < res_dims; ++i) {
res_f[i] = res.a;
if constexpr (Dims == Dynamic) {
if (check_residuals && res.v.size() != dims) {
TINYOPT_LOG("⚠️ Residual is not connected to the parameters");
J.row(i).setZero();
continue;
}
}
J.row(i) = res.v;
}
}
return std::make_pair(res_f, J);
}
}
/// Estimate the jacobian of d f(x)/d(x) around `x` using automatic
/// differentiation
template <typename X_t, typename CostOrResFunc>
auto CalculateJac(const X_t &x, const CostOrResFunc &cost_func, bool check_residuals = true) {
const auto &[res, J] = Eval(x, cost_func, check_residuals);
return J;
}
} // namespace tinyopt::diff