Commit 855853f

[gar] Move impl of lqrComputeKktError to hxx file, expose to Python

1 parent a53014c

File tree: 3 files changed (+107, -98 lines)

bindings/python/src/gar/expose-utils.cpp (7 additions, 0 deletions)

@@ -1,6 +1,8 @@
 #include "aligator/python/fwd.hpp"
 #include "aligator/gar/utils.hpp"
 
+#include <eigenpy/std-array.hpp>
+
 namespace aligator::python {
 using namespace gar;
 
@@ -33,5 +35,10 @@ void exposeGarUtils() {
           "Create or update a sparse matrix from an LQRProblem.");
 
   bp::def("lqrInitializeSolution", lqr_sol_initialize_wrap, ("problem"_a));
+
+  bp::def("lqrComputeKktError", lqrComputeKktError<Scalar>,
+          ("problem"_a, "xs", "us", "vs", "lbdas", "mudyn", "mueq", "theta",
+           "verbose"_a = false),
+          "Compute the KKT residual of the LQR problem.");
 }
 } // namespace aligator::python
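A minimal usage sketch of the new binding from Python. The module path (aligator.gar), the None value for the optional theta argument, and the sequence-style unpacking of the std::array return (presumably enabled by the eigenpy/std-array.hpp include added above) are assumptions, not confirmed API:

    from aligator import gar  # assumed module path

    def check_kkt(problem, xs, us, vs, lbdas, mudyn, mueq):
        # problem/xs/us/vs/lbdas are assumed to come from building and
        # solving an LQ problem elsewhere; mudyn and mueq are the
        # dual-regularization parameters used by the solver.
        dyn_err, cst_err, dual_err = gar.lqrComputeKktError(
            problem, xs, us, vs, lbdas, mudyn, mueq, None, verbose=False
        )
        return dyn_err, cst_err, dual_err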

gar/include/aligator/gar/utils.hpp (0 additions, 98 deletions)

@@ -59,104 +59,6 @@ std::array<Scalar, 3> lqrComputeKktError(
     const std::optional<typename math_types<Scalar>::ConstVectorRef> &theta_,
     bool verbose = false);
 
-template <typename Scalar>
-std::array<Scalar, 3> lqrComputeKktError(
-    const LQRProblemTpl<Scalar> &problem,
-    boost::span<const typename math_types<Scalar>::VectorXs> xs,
-    boost::span<const typename math_types<Scalar>::VectorXs> us,
-    boost::span<const typename math_types<Scalar>::VectorXs> vs,
-    boost::span<const typename math_types<Scalar>::VectorXs> lbdas,
-    const Scalar mudyn, const Scalar mueq,
-    const std::optional<typename math_types<Scalar>::ConstVectorRef> &theta_,
-    bool verbose) {
-  fmt::print("[{}] ", __func__);
-  uint N = (uint)problem.horizon();
-  using VectorXs = typename math_types<Scalar>::VectorXs;
-  using KnotType = LQRKnotTpl<Scalar>;
-
-  Scalar dynErr = 0.;
-  Scalar cstErr = 0.;
-  Scalar dualErr = 0.;
-  Scalar dNorm;
-
-  VectorXs _dyn;
-  VectorXs _cst;
-  VectorXs _gx;
-  VectorXs _gu;
-  VectorXs _gt;
-
-  // initial stage
-  {
-    _dyn = problem.g0 + problem.G0 * xs[0] - mudyn * lbdas[0];
-    dNorm = math::infty_norm(_dyn);
-    dynErr = std::max(dynErr, dNorm);
-    if (verbose)
-      fmt::print("d0 = {:.3e} \n", dNorm);
-  }
-  for (uint t = 0; t <= N; t++) {
-    const KnotType &knot = problem.stages[t];
-    auto _Str = knot.S.transpose();
-
-    if (verbose)
-      fmt::print("[{: >2d}] ", t);
-    _gx.setZero(knot.nx);
-    _gu.setZero(knot.nu);
-    _gt.setZero(knot.nth);
-
-    _cst = knot.C * xs[t] + knot.d - mueq * vs[t];
-    _gx.noalias() = knot.q + knot.Q * xs[t] + knot.C.transpose() * vs[t];
-    _gu.noalias() = knot.r + _Str * xs[t] + knot.D.transpose() * vs[t];
-
-    if (knot.nu > 0) {
-      _cst.noalias() += knot.D * us[t];
-      _gx.noalias() += knot.S * us[t];
-      _gu.noalias() += knot.R * us[t];
-    }
-
-    if (t == 0) {
-      _gx += problem.G0.transpose() * lbdas[0];
-    } else {
-      auto Et = problem.stages[t - 1].E.transpose();
-      _gx += Et * lbdas[t];
-    }
-
-    if (t < N) {
-      _dyn = knot.A * xs[t] + knot.B * us[t] + knot.f + knot.E * xs[t + 1] -
-             mudyn * lbdas[t + 1];
-      _gx += knot.A.transpose() * lbdas[t + 1];
-      _gu += knot.B.transpose() * lbdas[t + 1];
-
-      dNorm = math::infty_norm(_dyn);
-      if (verbose)
-        fmt::print(" |d| = {:.3e} | ", dNorm);
-      dynErr = std::max(dynErr, dNorm);
-    }
-
-    if (theta_.has_value()) {
-      Eigen::Ref<const VectorXs> th = theta_.value();
-      _gx.noalias() += knot.Gx * th;
-      _gu.noalias() += knot.Gu * th;
-      _gt = knot.gamma;
-      _gt.noalias() += knot.Gx.transpose() * xs[t];
-      if (knot.nu > 0)
-        _gt.noalias() += knot.Gu.transpose() * us[t];
-      _gt.noalias() += knot.Gth * th;
-    }
-
-    Scalar gxNorm = math::infty_norm(_gx);
-    Scalar guNorm = math::infty_norm(_gu);
-    Scalar cstNorm = math::infty_norm(_cst);
-    if (verbose)
-      fmt::print("|gx| = {:.3e} | |gu| = {:.3e} | |cst| = {:.3e}\n", gxNorm,
-                 guNorm, cstNorm);
-
-    dualErr = std::max({dualErr, gxNorm, guNorm});
-    cstErr = std::max(cstErr, cstNorm);
-  }
-
-  return std::array{dynErr, cstErr, dualErr};
-}
-
 /// @brief Fill in a KKT constraint matrix and vector for the given LQ problem
 /// with the given dual-regularization parameters @p mudyn and @p mueq.
 /// @returns Whether the matrices were successfully allocated.
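For reference, the three entries of the returned std::array are infinity-norm residuals of the LQ problem's KKT conditions, read off the implementation (now moved to the .hxx below); the dynamics residual also covers the initial constraint term g0 + G0 x0, and mudyn, mueq are the dual-regularization parameters:

\begin{aligned}
e_{\mathrm{dyn}}  &= \max\Bigl( \lVert g_0 + G_0 x_0 - \mu_{\mathrm{dyn}}\lambda_0 \rVert_\infty,\;
  \max_{t < N} \lVert A_t x_t + B_t u_t + f_t + E_t x_{t+1} - \mu_{\mathrm{dyn}}\lambda_{t+1} \rVert_\infty \Bigr), \\
e_{\mathrm{cst}}  &= \max_{t \le N} \lVert C_t x_t + D_t u_t + d_t - \mu_{\mathrm{eq}} v_t \rVert_\infty, \\
e_{\mathrm{dual}} &= \max_{t \le N} \max\bigl( \lVert g_{x,t} \rVert_\infty, \lVert g_{u,t} \rVert_\infty \bigr).
\end{aligned}

Here g_{x,t} and g_{u,t} are the stationarity residuals accumulated in _gx and _gu: the cost gradients plus the C^T v, S u / S^T x, E_{t-1}^T lambda_t, A_t^T lambda_{t+1}, B_t^T lambda_{t+1} terms and, when theta is passed, the parametric Gx theta and Gu theta contributions. Note that the parametric gradient _gt is computed but does not enter any of the three returned errors.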

gar/include/aligator/gar/utils.hxx (100 additions, 0 deletions)

@@ -91,6 +91,106 @@ void lqrCreateSparseMatrix(const LQRProblemTpl<Scalar> &problem,
   }
 }
 
+template <typename Scalar>
+std::array<Scalar, 3> lqrComputeKktError(
+    const LQRProblemTpl<Scalar> &problem,
+    boost::span<const typename math_types<Scalar>::VectorXs> xs,
+    boost::span<const typename math_types<Scalar>::VectorXs> us,
+    boost::span<const typename math_types<Scalar>::VectorXs> vs,
+    boost::span<const typename math_types<Scalar>::VectorXs> lbdas,
+    const Scalar mudyn, const Scalar mueq,
+    const std::optional<typename math_types<Scalar>::ConstVectorRef> &theta_,
+    bool verbose) {
+  if (verbose)
+    fmt::print("[{}] ", __func__);
+  uint N = (uint)problem.horizon();
+  assert(xs.size() == N + 1);
+  using VectorXs = typename math_types<Scalar>::VectorXs;
+  using KnotType = LQRKnotTpl<Scalar>;
+
+  Scalar dynErr = 0.;
+  Scalar cstErr = 0.;
+  Scalar dualErr = 0.;
+  Scalar dNorm;
+
+  VectorXs _dyn;
+  VectorXs _cst;
+  VectorXs _gx;
+  VectorXs _gu;
+  VectorXs _gt;
+
+  // initial stage
+  {
+    _dyn = problem.g0 + problem.G0 * xs[0] - mudyn * lbdas[0];
+    dNorm = math::infty_norm(_dyn);
+    dynErr = std::max(dynErr, dNorm);
+    if (verbose)
+      fmt::print("d0 = {:.3e} \n", dNorm);
+  }
+  for (uint t = 0; t <= N; t++) {
+    const KnotType &knot = problem.stages[t];
+    auto _Str = knot.S.transpose();
+
+    if (verbose)
+      fmt::print("[{: >2d}] ", t);
+    _gx.setZero(knot.nx);
+    _gu.setZero(knot.nu);
+    _gt.setZero(knot.nth);
+
+    _cst = knot.C * xs[t] + knot.d - mueq * vs[t];
+    _gx.noalias() = knot.q + knot.Q * xs[t] + knot.C.transpose() * vs[t];
+    _gu.noalias() = knot.r + _Str * xs[t] + knot.D.transpose() * vs[t];
+
+    if (knot.nu > 0) {
+      _cst.noalias() += knot.D * us[t];
+      _gx.noalias() += knot.S * us[t];
+      _gu.noalias() += knot.R * us[t];
+    }
+
+    if (t == 0) {
+      _gx += problem.G0.transpose() * lbdas[0];
+    } else {
+      auto Et = problem.stages[t - 1].E.transpose();
+      _gx += Et * lbdas[t];
+    }
+
+    if (t < N) {
+      _dyn = knot.A * xs[t] + knot.B * us[t] + knot.f + knot.E * xs[t + 1] -
+             mudyn * lbdas[t + 1];
+      _gx += knot.A.transpose() * lbdas[t + 1];
+      _gu += knot.B.transpose() * lbdas[t + 1];
+
+      dNorm = math::infty_norm(_dyn);
+      if (verbose)
+        fmt::print(" |d| = {:.3e} | ", dNorm);
+      dynErr = std::max(dynErr, dNorm);
+    }
+
+    if (theta_.has_value()) {
+      Eigen::Ref<const VectorXs> th = theta_.value();
+      _gx.noalias() += knot.Gx * th;
+      _gu.noalias() += knot.Gu * th;
+      _gt = knot.gamma;
+      _gt.noalias() += knot.Gx.transpose() * xs[t];
+      if (knot.nu > 0)
+        _gt.noalias() += knot.Gu.transpose() * us[t];
+      _gt.noalias() += knot.Gth * th;
+    }
+
+    Scalar gxNorm = math::infty_norm(_gx);
+    Scalar guNorm = math::infty_norm(_gu);
+    Scalar cstNorm = math::infty_norm(_cst);
+    if (verbose)
+      fmt::print("|gx| = {:.3e} | |gu| = {:.3e} | |cst| = {:.3e}\n", gxNorm,
+                 guNorm, cstNorm);
+
+    dualErr = std::max({dualErr, gxNorm, guNorm});
+    cstErr = std::max(cstErr, cstNorm);
+  }
+
+  return std::array{dynErr, cstErr, dualErr};
+}
+
 template <typename Scalar>
 bool lqrDenseMatrix(const LQRProblemTpl<Scalar> &problem, Scalar mudyn,
                     Scalar mueq, typename math_types<Scalar>::MatrixXs &mat,
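Since the function returns the errors in the order {dynErr, cstErr, dualErr}, a small caller-side helper (hypothetical, not part of aligator) turns the result into a convergence test:

    def kkt_converged(residuals, tol=1e-8):
        # Unpacking order mirrors the C++ return value
        # std::array{dynErr, cstErr, dualErr}.
        dyn_err, cst_err, dual_err = residuals
        return max(dyn_err, cst_err, dual_err) < tol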
