Skip to content

Commit f746a69

Browse files
committed
Add QR, LU, SVD linear solvers
1 parent d4293c4 commit f746a69

File tree

9 files changed

+345
-23
lines changed

9 files changed

+345
-23
lines changed

benchmarks/dense.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ static const bool enable_log = false;
2828
TEST_CASE("Float", "[benchmark][fixed][scalar]") {
2929
auto loss = [](const auto &x) { return x * x - 2.0f; };
3030
Options options = CreateOptions(enable_log);
31-
options.solver.use_ldlt = false;
31+
options.solver.linear_solver = tinyopt::solvers::Options2::LinearSolver::Inverse;
3232
options.solver.log.print_failure = true;
3333
BENCHMARK("√2") {
3434
float x = Vec1::Random()[0];
@@ -40,7 +40,7 @@ TEST_CASE("Float", "[benchmark][fixed][scalar]") {
4040
TEST_CASE("Double", "[benchmark][fixed][scalar]") {
4141
auto loss = [](const auto &x) { return x * x - 2.0; };
4242
Options options = CreateOptions(enable_log);
43-
options.solver.use_ldlt = false;
43+
options.solver.linear_solver = tinyopt::solvers::Options2::LinearSolver::Inverse;
4444
static StatCounter<double> counter;
4545
BENCHMARK("√2") {
4646
double x = Vec1::Random()[0]; // 0.480009157900 fails to converge

cmake/Options.cmake

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ option(TINYOPT_BUILD_PACKAGES "Build packages" OFF)
2828
option(TINYOPT_BUILD_DOCS "Build documentation" OFF)
2929

3030

31-
# Adding Definitions
31+
# Linear Solvers
32+
option(TINYOPT_BUILD_SOLVER_LDLT "Build Cholesky/LDLT linear solver" ON)
33+
option(TINYOPT_BUILD_SOLVER_QR "Build (col-piv) QR linear solver" ON)
34+
option(TINYOPT_BUILD_SOLVER_LU "Build (partial) LU linear solver" ON)
35+
option(TINYOPT_BUILD_SOLVER_SVD "Build SVD linear solver" ON)
36+
37+
# Adding Definitions # TODO: use target_compile_definitions()
3238
if (NOT TINYOPT_ENABLE_FORMATTERS)
3339
add_definitions(-DTINYOPT_NO_FORMATTERS=1)
3440
endif ()
@@ -38,3 +44,16 @@ endif ()
3844
if (TINYOPT_DISABLE_NUMDIFF)
3945
add_definitions(-DTINYOPT_DISABLE_NUMDIFF=1)
4046
endif ()
47+
# Linear solvers
48+
if (TINYOPT_BUILD_SOLVER_LDLT)
49+
add_definitions(-DTINYOPT_BUILD_SOLVER_LDLT=1)
50+
endif ()
51+
if (TINYOPT_BUILD_SOLVER_QR)
52+
add_definitions(-DTINYOPT_BUILD_SOLVER_QR=1)
53+
endif ()
54+
if (TINYOPT_BUILD_SOLVER_LU)
55+
add_definitions(-DTINYOPT_BUILD_SOLVER_LU=1)
56+
endif ()
57+
if (TINYOPT_BUILD_SOLVER_SVD)
58+
add_definitions(-DTINYOPT_BUILD_SOLVER_SVD=1)
59+
endif ()

include/tinyopt/math.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <optional>
77

88
#include <tinyopt/types.h>
9+
#include <Eigen/src/SVD/JacobiSVD.h>
910

1011
namespace tinyopt {
1112

@@ -276,6 +277,110 @@ std::optional<Vector<Scalar, RowsAtCompileTime>> SolveLDLT(
276277
return X;
277278
}
278279

280+
/**
281+
* @brief Solves the linear system A * X = B for X using LU decomposition.
282+
*
283+
* @tparam Derived The type of the matrix A.
284+
* @tparam Derived2 The type of the vector B.
285+
*
286+
* @param A The coefficient matrix A.
287+
* @param b The right-hand side vector B.
288+
*
289+
* @return An `std::optional` containing the solution vector X if the system is solvable, or
290+
* `std::nullopt` otherwise.
291+
*/
292+
template <typename Derived, typename Derived2>
293+
std::optional<Vector<typename Derived::Scalar, Derived::RowsAtCompileTime>> SolveLU(
294+
const MatrixBase<Derived> &A, const MatrixBase<Derived2> &b) {
295+
auto lu = A.partialPivLu();
296+
if (lu.determinant() != 0) {
297+
return lu.solve(b);
298+
}
299+
return std::nullopt;
300+
}
301+
302+
/**
303+
* @brief Solves the linear system A * X = B for X using QR decomposition.
304+
*
305+
* @tparam Derived The type of the matrix A.
306+
* @tparam Derived2 The type of the vector B.
307+
*
308+
* @param A The coefficient matrix A.
309+
* @param b The right-hand side vector B.
310+
*
311+
* @return An `std::optional` containing the solution vector X.
312+
*/
313+
template <typename Derived, typename Derived2>
314+
std::optional<Vector<typename Derived::Scalar, Derived::RowsAtCompileTime>> SolveQR(
315+
const MatrixBase<Derived> &A, const MatrixBase<Derived2> &b) {
316+
return A.colPivHouseholderQr().solve(b);
317+
}
318+
319+
/**
320+
* @brief Solves the linear system A * X = B for X using SVD decomposition.
321+
*
322+
* @tparam Derived The type of the matrix A.
323+
* @tparam Derived2 The type of the vector B.
324+
*
325+
* @param A The coefficient matrix A.
326+
* @param b The right-hand side vector B.
327+
*
328+
* @return An `std::optional` containing the solution vector X.
329+
*/
330+
template <typename Derived, typename Derived2>
331+
std::optional<Vector<typename Derived::Scalar, Derived::RowsAtCompileTime>> SolveSVD(
332+
const MatrixBase<Derived> &A, const MatrixBase<Derived2> &b) {
333+
return A.jacobiSvd(Eigen::ComputeThinU | Eigen::ComputeThinV).solve(b);
334+
}
335+
336+
337+
/**
338+
* @brief Solves the sparse linear system A * X = B for X using LU decomposition.
339+
*
340+
* @tparam Scalar The scalar type.
341+
* @tparam RowsAtCompileTime The compile-time number of rows of vector B.
342+
*
343+
* @param A The sparse coefficient matrix A.
344+
* @param b The right-hand side vector B.
345+
*
346+
* @return An `std::optional` containing the solution vector X if the system is solvable, or
347+
* `std::nullopt` otherwise.
348+
*/
349+
template <typename Scalar, int RowsAtCompileTime = Dynamic>
350+
std::optional<Vector<Scalar, RowsAtCompileTime>> SolveLU(
351+
const SparseMatrix<Scalar> &A, const Vector<Scalar, RowsAtCompileTime> &b) {
352+
Eigen::SparseLU<SparseMatrix<Scalar>> solver;
353+
solver.compute(A);
354+
if (solver.info() != Eigen::Success) return std::nullopt;
355+
auto X = solver.solve(b);
356+
if (solver.info() != Eigen::Success) return std::nullopt;
357+
return X;
358+
}
359+
360+
/**
361+
* @brief Solves the sparse linear system A * X = B for X using QR decomposition.
362+
*
363+
* @tparam Scalar The scalar type.
364+
* @tparam RowsAtCompileTime The compile-time number of rows of vector B.
365+
*
366+
* @param A The sparse coefficient matrix A.
367+
* @param b The right-hand side vector B.
368+
*
369+
* @return An `std::optional` containing the solution vector X if the system is solvable, or
370+
* `std::nullopt` otherwise.
371+
*/
372+
template <typename Scalar, int RowsAtCompileTime = Dynamic>
373+
std::optional<Vector<Scalar, RowsAtCompileTime>> SolveQR(
374+
const SparseMatrix<Scalar> &A, const Vector<Scalar, RowsAtCompileTime> &b) {
375+
Eigen::SparseQR<SparseMatrix<Scalar>, Eigen::COLAMDOrdering<int>> solver;
376+
solver.compute(A);
377+
if (solver.info() != Eigen::Success) return std::nullopt;
378+
auto X = solver.solve(b);
379+
if (solver.info() != Eigen::Success) return std::nullopt;
380+
return X;
381+
}
382+
383+
279384
/// Integer square root function for positive integers
280385
/// Will return `N` for negative or 0 values
281386
constexpr inline int SQRT(int N) {

include/tinyopt/solvers/gn.h

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ class SolverGN
4040
using Options = nlls::gn::SolverOptions;
4141

4242
explicit SolverGN(const Options &options = {}) : Base(options), options_{options} {
43-
// Sparse matrix must use LDLT
43+
// Inverse is not supported for sparse matrices
4444
if constexpr (traits::is_sparse_matrix_v<H_t>) {
45-
if (!options.use_ldlt) TINYOPT_LOG("Warning: LDLT must be used with Sparse Matrices");
45+
if (options.linear_solver == Options::LinearSolver::Inverse)
46+
TINYOPT_LOG("Warning: Inverse is not supported with Sparse Matrices");
4647
}
4748
}
4849

@@ -149,7 +150,7 @@ class SolverGN
149150

150151
// Fill the lower part if H if needed
151152
{
152-
if (!options_.H_is_full && !options_.use_ldlt) {
153+
if (!options_.H_is_full && options_.linear_solver != Options::LinearSolver::LDLT) {
153154
H_.template triangularView<Lower>() = H_.template triangularView<Upper>().transpose();
154155
}
155156
}
@@ -160,17 +161,49 @@ class SolverGN
160161
inline std::optional<Vector<Scalar, Dims>> Solve() const override {
161162
if (!this->cost().isValid()) return std::nullopt;
162163

164+
std::optional<Vector<Scalar, Dims>> dx_;
165+
163166
// Solver linear system
164-
if (options_.use_ldlt || traits::is_sparse_matrix_v<H_t>) {
165-
const auto dx_ = tinyopt::SolveLDLT(H_, -grad_);
166-
if (dx_) return dx_; // Hopefully not a copy...
167-
} else if constexpr (!traits::is_sparse_matrix_v<H_t>) { // Use default inverse
168-
if constexpr (Dims == 1) {
169-
if (H_(0, 0) > FloatEpsilon<Scalar>()) return -H_.inverse() * grad_;
170-
return Vector<Scalar, Dims>::Zero(grad_.size());
171-
} else
172-
return -H_.inverse() * grad_;
167+
switch (options_.linear_solver) {
168+
#if TINYOPT_BUILD_SOLVER_LDLT
169+
case Options::LinearSolver::LDLT:
170+
dx_ = tinyopt::SolveLDLT(H_, grad_);
171+
break;
172+
#endif
173+
#if TINYOPT_BUILD_SOLVER_LU
174+
case Options::LinearSolver::LU:
175+
dx_ = tinyopt::SolveLU(H_, grad_);
176+
break;
177+
#endif
178+
#if TINYOPT_BUILD_SOLVER_QR
179+
case Options::LinearSolver::QR:
180+
dx_ = tinyopt::SolveQR(H_, grad_);
181+
break;
182+
#endif
183+
#if TINYOPT_BUILD_SOLVER_SVD
184+
case Options::LinearSolver::SVD:
185+
dx_ = tinyopt::SolveSVD(H_, grad_);
186+
break;
187+
#endif
188+
case Options::LinearSolver::Inverse:
189+
if constexpr (!traits::is_sparse_matrix_v<H_t>) { // Use default inverse
190+
if constexpr (Dims == 1) {
191+
if (H_(0, 0) > FloatEpsilon<Scalar>()) dx_ = H_.inverse() * grad_;
192+
} else {
193+
dx_ = H_.inverse() * grad_;
194+
}
195+
}
196+
break;
197+
default:
198+
TINYOPT_LOG("❌ Unsupported linear solver");
199+
break;
200+
}
201+
202+
if (dx_) {
203+
dx_.value() = (-dx_.value()).eval();
204+
return dx_;
173205
}
206+
174207
// Log on failure
175208
if (options_.log.enable && options_.log.print_failure) {
176209
TINYOPT_LOG("❌ Failed solve linear system");

include/tinyopt/solvers/lm.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ class SolverLM : public tinyopt::solvers::SolverGN<Hessian_t> {
5252
using Options = nlls::lm::SolverOptions;
5353

5454
explicit SolverLM(const Options &options = {}) : Base(options), options_{options} {
55-
// Sparse matrix must use LDLT
55+
// Inverse is not supported for sparse matrices
5656
if constexpr (traits::is_sparse_matrix_v<H_t>) {
57-
if (!options.use_ldlt) TINYOPT_LOG("Warning: LDLT must be used with Sparse Matrices");
57+
if (options.linear_solver == Options::LinearSolver::Inverse)
58+
TINYOPT_LOG("Warning: Inverse is not supported with Sparse Matrices");
5859
}
5960
reset();
6061
}
@@ -104,7 +105,7 @@ class SolverLM : public tinyopt::solvers::SolverGN<Hessian_t> {
104105

105106
// Fill the lower part if H if needed
106107
if constexpr (!traits::is_sparse_matrix_v<H_t>) {
107-
if (!options_.H_is_full && !options_.use_ldlt) {
108+
if (!options_.H_is_full && options_.linear_solver != Options::LinearSolver::LDLT) {
108109
this->H_.template triangularView<Lower>() =
109110
this->H_.template triangularView<Upper>().transpose();
110111
}

include/tinyopt/solvers/options.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct Options1 {
1919
* @name Cost scaling options (mostly for NLLS solvers really)
2020
* @{
2121
*/
22-
struct CostScaling {
22+
struct CostScaling {
2323
bool use_squared_norm = true; ///< Use squared norm instead of norm (faster)
2424
bool downscale_by_2 = false; ///< Rescale the cost by 0.5
2525
/// Normalize the final error by the number of residuals (after use_squared_norm)
@@ -29,7 +29,7 @@ struct Options1 {
2929
/** @} */
3030

3131
struct {
32-
bool enable = true; // Enable solver logging
32+
bool enable = true; // Enable solver logging
3333
bool print_failure = false; // Log when a failure to solve the linear system happens
3434
} log;
3535
};
@@ -46,8 +46,20 @@ struct Options2 : Options1 {
4646
* @{
4747
*/
4848

49-
bool use_ldlt = true; ///< If not, will use H.inverse() without any checks on invertibility
50-
///< except for Dims==1
49+
enum class LinearSolver {
50+
LDLT,
51+
LU,
52+
QR,
53+
SVD,
54+
Inverse
55+
}; ///< Linear solver to use for the linear system
56+
LinearSolver linear_solver =
57+
#ifdef TINYOPT_BUILD_SOLVER_LDLT
58+
LinearSolver::LDLT;
59+
#else
60+
LinearSolver::Inverse;
61+
#endif
62+
5163
bool H_is_full = true; ///< Specify if H is only Upper triangularly or fully filled
5264

5365
/** @} */

tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ add_executable(tinyopt_test_solvers solvers.cpp)
2323
target_link_libraries(tinyopt_test_solvers PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)
2424
add_test_target(tinyopt_test_solvers)
2525

26+
add_executable(tinyopt_test_linear_solvers linear_solvers.cpp)
27+
target_link_libraries(tinyopt_test_linear_solvers PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)
28+
add_test_target(tinyopt_test_linear_solvers)
29+
2630
add_executable(tinyopt_test_optimizers optimizers.cpp)
2731
target_link_libraries(tinyopt_test_optimizers PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)
2832
add_test_target(tinyopt_test_optimizers)

0 commit comments

Comments
 (0)