@@ -49,13 +49,33 @@ template <typename VariationalSolver> class QSRPDE {
4949 }
5050
5151 // Functional penalized iterative reweighted least squares
52- template <typename ... LambdaT>
53- requires (std::is_convertible_v<LambdaT, double > && ...) && (sizeof ...(LambdaT) == n_lambda)
54- void fit (double alpha, LambdaT... lambda) {
52+ template <typename ... Args>
53+ requires (sizeof ...(Args) > 0 )
54+ auto fit (double alpha, Args&&... args) {
55+ vector_t lambda (n_lambda);
56+ internals::for_each_index_and_args<sizeof ...(Args)>(
57+ [&]<int Ns_, typename Ts_>(const Ts_& ts) {
58+ if (Ns_ < n_lambda) {
59+ fdapde_static_assert (
60+ std::is_convertible_v<Ts_ FDAPDE_COMMA double >, INVALID_SMOOTHING_PARAMETER_TYPE);
61+ lambda[Ns_] = ts;
62+ }
63+ },
64+ args...);
5565 matrix_t y = y_;
5666 // initialization
5767 solver_.update_response_and_weights (y, vector_t::Ones (n_obs_).asDiagonal ()); // restore solver state
58- solver_.nonparametric_fit ((2 . * lambda)...);
68+ internals::apply_index_pack_and_args<sizeof ...(Args)>(
69+ [&]<typename ... Ts_>(Ts_... ts) {
70+ solver_.nonparametric_fit ([&]() {
71+ if constexpr (Ts_::index < n_lambda) {
72+ return 2 . * ts.value ; // scale smoothing parameter
73+ } else {
74+ return ts.value ;
75+ }
76+ }()...);
77+ },
78+ args...);
5979 mu_ = solver_.Psi () * solver_.f ();
6080 double Jold = std::numeric_limits<double >::max (), Jnew = 0 ;
6181 n_iter_ = 0 ;
@@ -67,24 +87,21 @@ template <typename VariationalSolver> class QSRPDE {
6787 py_ = y - (1 - 2 . * alpha) * abs_res;
6888 // \argmin_{\beta, f} [ 1/n * \norm(W^{1/2} * (y - X * \beta - f_n))^2 + P_{\lambda}(f) ]
6989 solver_.update_response_and_weights (py_, pW_.asDiagonal ());
70- solver_.fit (lambda ...);
90+ solver_.fit (std::forward<Args>(args) ...);
7191 mu_ = fitted ();
7292 // prepare for next iteration
7393 double data_loss = (pW_.cwiseSqrt ().matrix ().asDiagonal () * (py_ - mu_)).squaredNorm () / n_obs_;
7494 Jold = Jnew;
75- Jnew = data_loss + solver_.ftPf (lambda... );
95+ Jnew = data_loss + solver_.ftPf (lambda);
7696 n_iter_++;
7797 }
78- return ;
79- }
80- template <typename ... LambdaT>
81- requires (std::is_convertible_v<LambdaT, double > && ...) && (sizeof ...(LambdaT) == n_lambda)
82- void fit (LambdaT... lambda) {
83- return fit (alpha_, lambda...);
98+ return std::make_pair (solver_.f (), solver_.beta ());
8499 }
100+ template <typename ... Args> auto fit (Args&&... args) { return fit (alpha_, std::forward<Args>(args)...); }
85101 // observers
86102 const vector_t & f () const { return solver_.f (); }
87103 const vector_t & beta () const { return solver_.beta (); }
104+ const vector_t & misfit () const { return solver_.misfit (); }
88105 int n_covs () const { return n_covs_; }
89106 int n_obs () const { return n_obs_; }
90107 double edf (int r = 100 , int seed = random_seed) { return solver_.edf (r, seed); }
0 commit comments