#pragma once

#include <stdexcept>
#include <type_traits>

#include <tinyopt/optimizers/optimizer.h>
#include <tinyopt/optimizers/optimizers.h>
#include <tinyopt/optimizers/options.h>

#include "tinyopt/log.h"
|
8 | 12 | namespace tinyopt { |
9 | 13 |
|
10 | 14 | /// Simplest interface to optimize `x` and minimize residuals (loss function). |
11 | 15 | /// Internally call the optimizer and run the optimization. |
12 | | -template <typename Optimizer, typename X_t, typename Res_t> |
13 | | -inline auto Optimize(X_t &x, const Res_t &func, const typename Optimizer::Options &options = {}) { |
14 | | - Optimizer optimizer(options); |
15 | | - return optimizer(x, func); |
| 16 | +template <typename T, typename Func> |
| 17 | +inline Output Optimize(T &x, const Func &func, const Options &options = {}) { |
| 18 | + // Detect Scalar, supporting at most one nesting level |
| 19 | + using Scalar = std::conditional_t< |
| 20 | + std::is_scalar_v<typename traits::params_trait<T>::Scalar>, |
| 21 | + typename traits::params_trait<T>::Scalar, |
| 22 | + typename traits::params_trait<typename traits::params_trait<T>::Scalar>::Scalar>; |
| 23 | + static_assert(std::is_scalar_v<Scalar>); |
| 24 | + constexpr Index Dims = traits::params_trait<T>::Dims; |
| 25 | + |
| 26 | + // Detect Hessian Type, if it's dense or sparse |
| 27 | + constexpr bool isDense = |
| 28 | + std::is_invocable_v<Func, const T &> || |
| 29 | + std::is_invocable_v<Func, const T &, Vector<Scalar, Dims> &> || |
| 30 | + std::is_invocable_v<Func, const T &, Vector<Scalar, Dims> &, Matrix<Scalar, Dims, Dims> &>; |
| 31 | + |
| 32 | + using Hessian_t = std::conditional_t<isDense, Matrix<Scalar, Dims, Dims>, SparseMatrix<Scalar>>; |
| 33 | + using Gradient_t = std::conditional_t<isDense, Vector<Scalar, Dims>, SparseMatrix<Scalar>>; |
| 34 | + |
| 35 | + constexpr bool secondOrderValid = !std::is_invocable_v<Func, const T &, Vector<Scalar, Dims> &>; |
| 36 | + |
| 37 | + // Check if this is an unconstrained first order problem |
| 38 | + constexpr bool firstOrderAllowed = !secondOrderValid; |
| 39 | + |
| 40 | + switch (options.solver_type) { |
| 41 | + // Second order methods |
| 42 | + case Options::Solver::GaussNewton: |
| 43 | + if constexpr (secondOrderValid) { |
| 44 | + gn::Optimizer<Hessian_t> optimizer(options); |
| 45 | + return optimizer(x, func); |
| 46 | + } else { |
| 47 | + throw std::invalid_argument( |
| 48 | + "Error: GaussNewton can't be used on this gradient only function"); |
| 49 | + } |
| 50 | + case Options::Solver::LevenbergMarquardt: |
| 51 | + if constexpr (secondOrderValid) { |
| 52 | + lm::Optimizer<Hessian_t> optimizer(options); |
| 53 | + return optimizer(x, func); |
| 54 | + } else { |
| 55 | + throw std::invalid_argument( |
| 56 | + "Error: LevenbergMarquardt can't be used on this gradient only function"); |
| 57 | + } |
| 58 | + // First order methods |
| 59 | + case Options::Solver::GradientDescent: |
| 60 | + if constexpr (std::is_invocable_v<Func, const T &>) { |
| 61 | + using ReturnType = std::invoke_result_t<Func, T>; |
| 62 | + if constexpr (traits::is_scalar_v<ReturnType>) { |
| 63 | + gd::Optimizer<Gradient_t> optimizer(options); |
| 64 | + return optimizer(x, func); |
| 65 | + } else { |
| 66 | + throw std::invalid_argument( |
| 67 | + "Error: cost function must return a scalar for Gradient Descent"); |
| 68 | + } |
| 69 | + } else if constexpr (firstOrderAllowed) { |
| 70 | + gd::Optimizer<Gradient_t> optimizer(options); |
| 71 | + return optimizer(x, func); |
| 72 | + } |
| 73 | + default: |
| 74 | + TINYOPT_LOG("❌ Error: Unknown solver type {}", (int)options.solver_type); |
| 75 | + throw std::invalid_argument("Error: Unknown solver type"); |
| 76 | + } |
16 | 77 | } |
17 | 78 |
|
18 | 79 | } // namespace tinyopt |