Skip to content

Commit ef32b1c

Browse files
committed
sr, gsr, qsr expose misfit(). sr, gsr, qsr accepting a generic parameter pack as fit() input arguments
1 parent 18dc5d2 commit ef32b1c

File tree

6 files changed

+46
-27
lines changed

6 files changed

+46
-27
lines changed

fdaPDE/src/models/gsr.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,14 @@ template <typename VariationalSolver, typename Distribution> class GSRPDE {
4949
}
5050

5151
// Functional penalized iterative reweighted least squares
52-
template <typename... LambdaT>
53-
requires(std::is_convertible_v<LambdaT, double> && ...)
54-
void fit(LambdaT... lambda) {
52+
template <typename... Args> auto fit(Args&&... args) {
53+
vector_t lambda(n_lambda);
54+
internals::for_each_index_and_args<sizeof...(Args)>([&]<int Ns_, typename Ts_>(const Ts_& ts) {
55+
if (Ns_ < n_lambda) {
56+
fdapde_static_assert(std::is_convertible_v<Ts_ FDAPDE_COMMA double>, INVALID_SMOOTHING_PARAMETER_TYPE);
57+
lambda[Ns_] = ts;
58+
}
59+
});
5560
// initialize mean vector
5661
vector_t y = y_;
5762
solver_.update_response_and_weights(y, vector_t::Ones(n_obs_).asDiagonal()); // restore solver state
@@ -68,20 +73,21 @@ template <typename VariationalSolver, typename Distribution> class GSRPDE {
6873
py_ = G.asDiagonal() * (y - mu_) + distr_.link(mu_);
6974
// \argmin_{\beta, f} [ \norm(W^{1/2} * (y - X * \beta - f_n))^2 + P_{\lambda}(f) ]
7075
solver_.update_response_and_weights(py_, pW_.asDiagonal());
71-
solver_.fit(lambda...);
76+
solver_.fit(std::forward<Args>(args)...);
7277
mu_ = distr_.inv_link(fitted());
7378
// prepare for next iteration
7479
double data_loss =
7580
(distr_.variance(mu_).array().sqrt().inverse().matrix().asDiagonal() * (y - mu_)).squaredNorm() / n_obs_;
7681
Jold = Jnew;
77-
Jnew = data_loss + solver_.ftPf(lambda...);
82+
Jnew = data_loss + solver_.ftPf(lambda);
7883
n_iter_++;
7984
}
80-
return;
85+
return std::make_pair(solver_.f(), solver_.beta());
8186
}
8287
// observers
8388
const vector_t& f() const { return solver_.f(); }
8489
const vector_t& beta() const { return solver_.beta(); }
90+
const vector_t& misfit() const { return solver_.misfit(); }
8591
int n_covs() const { return n_covs_; }
8692
int n_obs() const { return n_obs_; }
8793
double edf(int r = 100, int seed = random_seed) { return solver_.edf(r, seed); }

fdaPDE/src/models/qsr.h

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,33 @@ template <typename VariationalSolver> class QSRPDE {
4949
}
5050

5151
// Functional penalized iterative reweighted least squares
52-
template <typename... LambdaT>
53-
requires(std::is_convertible_v<LambdaT, double> && ...) && (sizeof...(LambdaT) == n_lambda)
54-
void fit(double alpha, LambdaT... lambda) {
52+
template <typename... Args>
53+
requires(sizeof...(Args) > 0)
54+
auto fit(double alpha, Args&&... args) {
55+
vector_t lambda(n_lambda);
56+
internals::for_each_index_and_args<sizeof...(Args)>(
57+
[&]<int Ns_, typename Ts_>(const Ts_& ts) {
58+
if (Ns_ < n_lambda) {
59+
fdapde_static_assert(
60+
std::is_convertible_v<Ts_ FDAPDE_COMMA double>, INVALID_SMOOTHING_PARAMETER_TYPE);
61+
lambda[Ns_] = ts;
62+
}
63+
},
64+
args...);
5565
matrix_t y = y_;
5666
// initialization
5767
solver_.update_response_and_weights(y, vector_t::Ones(n_obs_).asDiagonal()); // restore solver state
58-
solver_.nonparametric_fit((2. * lambda)...);
68+
internals::apply_index_pack_and_args<sizeof...(Args)>(
69+
[&]<typename... Ts_>(Ts_... ts) {
70+
solver_.nonparametric_fit([&]() {
71+
if constexpr (Ts_::index < n_lambda) {
72+
return 2. * ts.value; // scale smoothing parameter
73+
} else {
74+
return ts.value;
75+
}
76+
}()...);
77+
},
78+
args...);
5979
mu_ = solver_.Psi() * solver_.f();
6080
double Jold = std::numeric_limits<double>::max(), Jnew = 0;
6181
n_iter_ = 0;
@@ -67,24 +87,21 @@ template <typename VariationalSolver> class QSRPDE {
6787
py_ = y - (1 - 2. * alpha) * abs_res;
6888
// \argmin_{\beta, f} [ 1/n * \norm(W^{1/2} * (y - X * \beta - f_n))^2 + P_{\lambda}(f) ]
6989
solver_.update_response_and_weights(py_, pW_.asDiagonal());
70-
solver_.fit(lambda...);
90+
solver_.fit(std::forward<Args>(args)...);
7191
mu_ = fitted();
7292
// prepare for next iteration
7393
double data_loss = (pW_.cwiseSqrt().matrix().asDiagonal() * (py_ - mu_)).squaredNorm() / n_obs_;
7494
Jold = Jnew;
75-
Jnew = data_loss + solver_.ftPf(lambda...);
95+
Jnew = data_loss + solver_.ftPf(lambda);
7696
n_iter_++;
7797
}
78-
return;
79-
}
80-
template <typename... LambdaT>
81-
requires(std::is_convertible_v<LambdaT, double> && ...) && (sizeof...(LambdaT) == n_lambda)
82-
void fit(LambdaT... lambda) {
83-
return fit(alpha_, lambda...);
98+
return std::make_pair(solver_.f(), solver_.beta());
8499
}
100+
template <typename... Args> auto fit(Args&&... args) { return fit(alpha_, std::forward<Args>(args)...); }
85101
// observers
86102
const vector_t& f() const { return solver_.f(); }
87103
const vector_t& beta() const { return solver_.beta(); }
104+
const vector_t& misfit() const { return solver_.misfit(); }
88105
int n_covs() const { return n_covs_; }
89106
int n_obs() const { return n_obs_; }
90107
double edf(int r = 100, int seed = random_seed) { return solver_.edf(r, seed); }

fdaPDE/src/models/sr.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,11 @@ class SRPDE {
4444
}
4545
solver_ = solver_t(formula, gf, penalty.get());
4646
}
47-
template <typename... LambdaT>
48-
requires(std::is_convertible_v<LambdaT, double> && ...) ||
49-
(sizeof...(LambdaT) == 1 && (internals::is_vector_like_v<LambdaT> && ...))
50-
void fit(LambdaT... lambda) {
51-
solver_.fit(lambda...);
52-
}
47+
template <typename... Args> auto fit(Args&&... args) { return solver_.fit(std::forward<Args>(args)...); }
5348
// observers
5449
const vector_t& f() const { return solver_.f(); }
5550
const vector_t& beta() const { return solver_.beta(); }
51+
const vector_t& misfit() const { return solver_.misfit(); }
5652
int n_covs() const { return n_covs_; }
5753
int n_obs() const { return n_obs_; }
5854
double edf(int r = 100, int seed = random_seed) { return solver_.edf(r, seed); }

fdaPDE/src/solvers/fe_ls_parabolic.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ class fe_ls_parabolic_mono {
304304
if (nan_pattern.any()) {
305305
n_obs_ = n_locs_ - nan_pattern.count();
306306
B_ = (~nan_pattern).repeat(1, m_ * n_dofs_).select(Psi_, 0);
307-
y_ = (~nan_pattern).select(y_, 0);
307+
y_ = (~nan_pattern).select(y_, 0);
308308
}
309309
if (old_n_obs != n_obs_) { W_ *= (double)old_n_obs / n_obs_; } // re-normalize
310310
b_.block(0, 0, m_ * n_dofs_, 1) = -PsiNA().transpose() * D_ * W_ * y;

fdaPDE/src/solvers/fe_ls_separable.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ class fe_ls_separable_mono {
294294
if (nan_pattern.any()) {
295295
n_obs_ = n_locs_ - nan_pattern.count();
296296
B_ = (~nan_pattern).repeat(1, n_dofs_).select(Psi_, 0);
297-
y_ = (~nan_pattern).select(y_, 0);
297+
y_ = (~nan_pattern).select(y_, 0);
298298
}
299299
if (old_n_obs != n_obs_) { W_ *= (double)old_n_obs / n_obs_; }
300300
b_.block(0, 0, n_dofs_, 1) = -PsiNA().transpose() * D_ * W_ * y;

0 commit comments

Comments
 (0)