Skip to content

Commit f4171b3

Browse files
committed
bug fix: solver specific parameters correctly forwarded from solver factories to solver implementation, core submodule update (bug fix on density estimation models)
1 parent d9f3138 commit f4171b3

File tree

15 files changed

+594
-415
lines changed

15 files changed

+594
-415
lines changed

fdaPDE/src/models/de.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ class DEPDE {
3232

3333
DEPDE() noexcept = default;
3434
template <typename GeoFrame, typename Penalty> DEPDE(const GeoFrame& gf, Penalty&& penalty) noexcept : solver_() {
35-
analyze_data(gf);
3635
discretize(penalty.get());
36+
analyze_data(gf);
3737
}
3838
template <typename... Args> const vector_t& fit(Args&&... args) { return solver_.fit(std::forward<Args>(args)...); }
3939
template <typename... Args> void discretize(Args&&... args) { solver_.discretize(std::forward<Args>(args)...); }
@@ -57,7 +57,7 @@ class DEPDE {
5757

5858
// deduction guide
5959
template <typename GeoFrame, typename Penalty>
60-
DEPDE(const GeoFrame& gf, Penalty&& solver) -> DEPDE<typename Penalty::solver_t>;
60+
DEPDE(const GeoFrame& gf, Penalty&& solver) -> DEPDE<typename std::decay_t<Penalty>::solver_t>;
6161

6262
} // namespace fdapde
6363

fdaPDE/src/models/fpca.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ template <typename VariationalSolver> class fpca_power_iteration_impl {
7676
} break;
7777
case OptimizeGCV: {
7878
auto gcv_functor = [&](auto lambda) { return gcv_(X, lambda, V.col(i)); };
79-
// GridOptimizer<n_lambda> optimizer;
80-
GridOptimizer<n_lambda> optimizer;
79+
GridSearch<n_lambda> optimizer;
8180
opt_lambda = optimizer.optimize(gcv_functor, lambda_grid);
8281
} break;
8382
case OptimizeMSRE: {
@@ -197,8 +196,7 @@ template <typename VariationalSolver> class fpca_subspace_iteration_impl {
197196
} break;
198197
case OptimizeGCV: {
199198
auto gcv_functor = [&](auto lambda) { return gcv_(X, rank, lambda, V); };
200-
// GridOptimizer<n_lambda> optimizer;
201-
GridOptimizer<n_lambda> optimizer;
199+
GridSearch<n_lambda> optimizer;
202200
opt_lambda = optimizer.optimize(gcv_functor, lambda_grid);
203201
} break;
204202
case OptimizeMSRE: {
@@ -312,8 +310,7 @@ template <typename VariationalSolver> class fpca_direct_impl {
312310
} break;
313311
case OptimizeGCV: {
314312
auto gcv_functor = [&](auto lambda) { return gcv_(X, rank, lambda, flag); };
315-
// GridOptimizer<n_lambda> optimizer;
316-
GridOptimizer<n_lambda> optimizer;
313+
GridSearch<n_lambda> optimizer;
317314
opt_lambda = optimizer.optimize(gcv_functor, lambda_grid);
318315
} break;
319316
case OptimizeMSRE: {

fdaPDE/src/models/gsr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ class GSRPDE {
3333
GSRPDE() noexcept : distr_(), solver_() { }
3434
template <typename GeoFrame, typename Penalty>
3535
GSRPDE(const std::string& formula, const GeoFrame& gf, Penalty&& penalty) noexcept : distr_(), solver_() {
36-
discretize(penalty.get().penalty);
36+
discretize(penalty.get());
3737
analyze_data(formula, gf);
3838
}
3939
template <typename GeoFrame, typename Distribution, typename Penalty>
4040
GSRPDE(const std::string& formula, const GeoFrame& gf, const Distribution& distr, Penalty&& penalty) noexcept :
4141
GSRPDE(formula, gf, penalty) {
42-
discretize(penalty.get().penalty);
42+
discretize(penalty.get());
4343
analyze_data(formula, gf);
4444
set_family(distr);
4545
}

fdaPDE/src/models/qsr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ template <typename VariationalSolver> class QSRPDE {
3232
template <typename GeoFrame, typename Penalty>
3333
QSRPDE(const std::string& formula, const GeoFrame& gf, double alpha, Penalty&& penalty) noexcept :
3434
solver_(), alpha_(alpha) {
35-
discretize(penalty.get().penalty);
35+
discretize(penalty.get());
3636
analyze_data(formula, gf);
3737
}
38-
template <typename GeoFrame, typename Penalty>
38+
template <typename GeoFrame, typename Penalty> // default to median fitting
3939
QSRPDE(const std::string& formula, const GeoFrame& gf, Penalty&& penalty) noexcept :
4040
QSRPDE(formula, gf, 0.5, penalty) { }
4141

fdaPDE/src/models/sr.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class SRPDE {
3535
template <typename GeoFrame, typename Penalty>
3636
SRPDE(const std::string& formula, const GeoFrame& gf, Penalty&& penalty) noexcept :
3737
solver_(), geo_category_(gf[0].category().begin(), gf[0].category().end()) {
38-
discretize(penalty.get().penalty);
38+
discretize(penalty.get());
3939
analyze_data(formula, gf);
4040
}
4141
// modifiers
@@ -349,7 +349,8 @@ class SRPDE {
349349

350350
// deduction guide
351351
template <typename GeoFrame, typename Penalty>
352-
SRPDE(const std::string& formula, const GeoFrame& gf, Penalty&& solver) -> SRPDE<typename Penalty::solver_t>;
352+
SRPDE(const std::string& formula, const GeoFrame& gf, Penalty&& solver)
353+
-> SRPDE<typename std::decay_t<Penalty>::solver_t>;
353354

354355
} // namespace fdapde
355356

fdaPDE/src/solvers/fe_de_elliptic.h

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@ struct fe_de_elliptic {
3131
using sparse_solver_t = eigen_sparse_solver_movable_wrap<Eigen::SparseLU<sparse_matrix_t>>;
3232
template <typename DataLocs>
3333
static constexpr bool is_valid_data_locs_descriptor_v = std::is_same_v<DataLocs, matrix_t>;
34-
template <typename InfoT> struct is_valid_info_t {
35-
static constexpr bool value = requires(InfoT info) { info.penalty; };
34+
template <typename Penalty> struct is_valid_penalty {
35+
static constexpr bool value = requires(Penalty penalty) {
36+
penalty.bilinear_form();
37+
penalty.linear_form();
38+
};
3639
};
40+
template <typename Penalty> static constexpr bool is_valid_penalty_v = is_valid_penalty<Penalty>::value;
3741
// high-order quadrature for integration of constraint \int_D (e^g)
3842
template <int EmbedDim> struct de_quadrature {
3943
using type = std::conditional_t<EmbedDim == 1, QS1DP7_, std::conditional_t<EmbedDim == 2, QS2DP4_, QS3DP5_>>;
@@ -76,30 +80,30 @@ struct fe_de_elliptic {
7680
};
7781

7882
fe_de_elliptic() noexcept = default;
79-
template <typename GeoFrame, typename InfoT>
80-
requires(is_valid_info_t<InfoT>::value)
81-
fe_de_elliptic(const GeoFrame& gf, InfoT&& info) {
83+
template <typename GeoFrame, typename Penalty>
84+
requires(is_valid_penalty_v<Penalty>)
85+
fe_de_elliptic(const GeoFrame& gf, Penalty&& penalty) {
8286
fdapde_static_assert(GeoFrame::Order == 1, THIS_CLASS_IS_FOR_ORDER_ONE_GEOFRAMES_ONLY);
83-
discretize(gf, info.penalty);
87+
discretize(penalty);
8488
analyze_data(gf);
8589
}
8690

8791
// perform finite element based numerical discretization
88-
template <typename GeoFrame, typename Penalty> void discretize(const GeoFrame& gf, Penalty&& penalty) {
89-
fdapde_static_assert(internals::is_valid_penalty_pair_v<Penalty>, INVALID_PENALTY_DESCRIPTION);
90-
using BilinearForm = std::tuple_element_t<0, std::decay_t<Penalty>>;
91-
using LinearForm = std::tuple_element_t<1, std::decay_t<Penalty>>;
92+
template <typename Penalty> void discretize(Penalty&& penalty) {
93+
using BilinearForm = typename std::decay_t<Penalty>::BilinearForm;
94+
using LinearForm = typename std::decay_t<Penalty>::LinearForm;
95+
fdapde_static_assert(
96+
internals::is_valid_penalty_pair_v<BilinearForm FDAPDE_COMMA LinearForm>, INVALID_PENALTY_DESCRIPTION);
9297
using FeSpace = typename BilinearForm::TrialSpace;
9398
using DofHandler = typename FeSpace::DofHandlerType;
9499
using Triangulation = typename FeSpace::Triangulation;
95100
constexpr int embed_dim = Triangulation::embed_dim;
96101

97102
// discretization
98-
const FeSpace& fe_space = std::get<0>(penalty).trial_space();
99-
const Triangulation& triangulation = gf.template triangulation<0>();
103+
const BilinearForm& bilinear_form = penalty.bilinear_form();
104+
const LinearForm& linear_form = penalty.linear_form();
105+
const FeSpace& fe_space = bilinear_form.trial_space();
100106
const DofHandler& dof_handler = fe_space.dof_handler();
101-
const BilinearForm& bilinear_form = std::get<0>(penalty);
102-
const LinearForm& linear_form = std::get<1>(penalty);
103107
n_dofs_ = bilinear_form.n_dofs(); // number of basis functions over physical domain
104108
internals::fe_mass_assembly_loop<FeSpace> mass_assembler(bilinear_form.trial_space());
105109
R0_ = mass_assembler.assemble();
@@ -114,7 +118,10 @@ struct fe_de_elliptic {
114118
point_eval_ = [fe_space = bilinear_form.trial_space()](const matrix_t& locs) -> decltype(auto) {
115119
return internals::point_basis_eval(fe_space, locs);
116120
};
117-
// eval reference basis at quadrature nodes, store de_quadrature weights
121+
122+
// geometry
123+
const Triangulation& triangulation = fe_space.triangulation();
124+
// eval reference basis at quadrature nodes, store de_quadrature weights
118125
de_quadrature_t<embed_dim> quad_rule;
119126
int n_quad_nodes = quad_rule.order;
120127
int n_shape_functions = fe_space.n_shape_functions();
@@ -144,15 +151,14 @@ struct fe_de_elliptic {
144151
it->measure();
145152
}
146153
return grad;
147-
};
154+
};
148155
return;
149156
}
150157
// fit from geoframe
151158
template <typename GeoFrame> void analyze_data(const GeoFrame& gf) {
152159
fdapde_static_assert(GeoFrame::Order == 1, THIS_CLASS_IS_FOR_ORDER_ONE_GEOFRAMES_ONLY);
153160
fdapde_assert(gf.n_layers() == 1 && gf[0].category()[0] == ltype::point);
154161
n_obs_ = gf[0].rows();
155-
156162
// eval physical basis at spatial locations
157163
const auto& spatial_index = geo_index_cast<0, POINT>(gf[0]);
158164
if (spatial_index.points_at_dofs()) {
@@ -164,15 +170,16 @@ struct fe_de_elliptic {
164170
return;
165171
}
166172
// main fit entry point
167-
template <typename Optimizer> const vector_t& fit(double lambda, const vector_t& g_init, Optimizer&& opt) {
168-
g_ = opt.optimize(llik_t(*this, lambda, tol_), g_init);
173+
template <typename Optimizer, typename... Callbacks>
174+
const vector_t& fit(double lambda, const vector_t& g_init, Optimizer&& opt, Callbacks&&... callbacks) {
175+
g_ = opt.optimize(llik_t(*this, lambda, tol_), g_init, std::forward<Callbacks>(callbacks)...);
169176
return g_;
170177
}
171-
template <typename Optimizer, typename LambdaT>
178+
template <typename Optimizer, typename LambdaT, typename... Callbacks>
172179
requires(internals::is_vector_like_v<LambdaT>)
173-
const vector_t& fit(LambdaT&& lambda, const vector_t& g_init, Optimizer&& opt) {
180+
const vector_t& fit(LambdaT&& lambda, const vector_t& g_init, Optimizer&& opt, Callbacks&&... callbacks) {
174181
fdapde_assert(lambda.size() == n_lambda);
175-
return fit(lambda[0]);
182+
return fit(lambda[0], g_init, opt, std::forward<Callbacks>(callbacks)...);
176183
}
177184
// modifiers
178185
void set_llik_tolerance(double tol) { tol_ = tol; }
@@ -207,21 +214,31 @@ struct fe_de_elliptic {
207214

208215
} // namespace internals
209216

210-
// elliptic solver factory
211-
template <typename BilinearForm, typename LinearForm> struct fe_de_elliptic {
217+
// elliptic solver API
218+
template <typename BilinearForm_, typename LinearForm_> struct fe_de_elliptic {
212219
using solver_t = internals::fe_de_elliptic;
213220
private:
214-
struct info_t {
215-
std::tuple<BilinearForm, LinearForm> penalty;
221+
struct penalty_packet {
222+
using BilinearForm = std::decay_t<BilinearForm_>;
223+
using LinearForm = std::decay_t<LinearForm_>;
224+
private:
225+
BilinearForm bilinear_form_;
226+
LinearForm linear_form_;
227+
public:
228+
penalty_packet(const BilinearForm_& bilinear_form, const LinearForm_& linear_form) :
229+
bilinear_form_(bilinear_form), linear_form_(linear_form) { }
230+
// observers
231+
const BilinearForm& bilinear_form() const { return bilinear_form_; }
232+
const LinearForm& linear_form() const { return linear_form_; }
216233
};
217234
public:
218-
fe_de_elliptic(const BilinearForm& bilinear_form, const LinearForm& linear_form) :
219-
info_(std::make_tuple(bilinear_form, linear_form)) { }
220-
const info_t& get() const { return info_; }
235+
fe_de_elliptic(const BilinearForm_& bilinear_form, const LinearForm_& linear_form) :
236+
penalty_(bilinear_form, linear_form) { }
237+
const penalty_packet& get() const { return penalty_; }
221238
private:
222-
info_t info_;
223-
};
224-
239+
penalty_packet penalty_;
240+
};
241+
225242
} // namespace fdapde
226243

227244
#endif // __FE_DE_ELLIPTIC_SOLVER_H__

0 commit comments

Comments
 (0)