@@ -31,9 +31,13 @@ struct fe_de_elliptic {
3131 using sparse_solver_t = eigen_sparse_solver_movable_wrap<Eigen::SparseLU<sparse_matrix_t >>;
3232 template <typename DataLocs>
3333 static constexpr bool is_valid_data_locs_descriptor_v = std::is_same_v<DataLocs, matrix_t >;
34- template <typename InfoT> struct is_valid_info_t {
35- static constexpr bool value = requires (InfoT info) { info.penalty ; };
34+ template <typename Penalty> struct is_valid_penalty {
35+ static constexpr bool value = requires (Penalty penalty) {
36+ penalty.bilinear_form ();
37+ penalty.linear_form ();
38+ };
3639 };
40+ template <typename Penalty> static constexpr bool is_valid_penalty_v = is_valid_penalty<Penalty>::value;
3741 // high-order quadrature for integration of constraint \int_D (e^g)
3842 template <int EmbedDim> struct de_quadrature {
3943 using type = std::conditional_t <EmbedDim == 1 , QS1DP7_, std::conditional_t <EmbedDim == 2 , QS2DP4_, QS3DP5_>>;
@@ -76,30 +80,30 @@ struct fe_de_elliptic {
7680 };
7781
7882 fe_de_elliptic () noexcept = default ;
79- template <typename GeoFrame, typename InfoT >
80- requires (is_valid_info_t <InfoT>::value )
81- fe_de_elliptic (const GeoFrame& gf, InfoT && info ) {
83+ template <typename GeoFrame, typename Penalty >
84+ requires (is_valid_penalty_v<Penalty> )
85+ fe_de_elliptic (const GeoFrame& gf, Penalty && penalty ) {
8286 fdapde_static_assert (GeoFrame::Order == 1 , THIS_CLASS_IS_FOR_ORDER_ONE_GEOFRAMES_ONLY);
83- discretize (gf, info. penalty );
87+ discretize (penalty);
8488 analyze_data (gf);
8589 }
8690
8791 // perform finite element based numerical discretization
88- template <typename GeoFrame, typename Penalty> void discretize (const GeoFrame& gf, Penalty&& penalty) {
89- fdapde_static_assert (internals::is_valid_penalty_pair_v<Penalty>, INVALID_PENALTY_DESCRIPTION);
90- using BilinearForm = std::tuple_element_t <0 , std::decay_t <Penalty>>;
91- using LinearForm = std::tuple_element_t <1 , std::decay_t <Penalty>>;
92+ template <typename Penalty> void discretize (Penalty&& penalty) {
93+ using BilinearForm = typename std::decay_t <Penalty>::BilinearForm;
94+ using LinearForm = typename std::decay_t <Penalty>::LinearForm;
95+ fdapde_static_assert (
96+ internals::is_valid_penalty_pair_v<BilinearForm FDAPDE_COMMA LinearForm>, INVALID_PENALTY_DESCRIPTION);
9297 using FeSpace = typename BilinearForm::TrialSpace;
9398 using DofHandler = typename FeSpace::DofHandlerType;
9499 using Triangulation = typename FeSpace::Triangulation;
95100 constexpr int embed_dim = Triangulation::embed_dim;
96101
97102 // discretization
98- const FeSpace& fe_space = std::get<0 >(penalty).trial_space ();
99- const Triangulation& triangulation = gf.template triangulation <0 >();
103+ const BilinearForm& bilinear_form = penalty.bilinear_form ();
104+ const LinearForm& linear_form = penalty.linear_form ();
105+ const FeSpace& fe_space = bilinear_form.trial_space ();
100106 const DofHandler& dof_handler = fe_space.dof_handler ();
101- const BilinearForm& bilinear_form = std::get<0 >(penalty);
102- const LinearForm& linear_form = std::get<1 >(penalty);
103107 n_dofs_ = bilinear_form.n_dofs (); // number of basis functions over physical domain
104108 internals::fe_mass_assembly_loop<FeSpace> mass_assembler (bilinear_form.trial_space ());
105109 R0_ = mass_assembler.assemble ();
@@ -114,7 +118,10 @@ struct fe_de_elliptic {
114118 point_eval_ = [fe_space = bilinear_form.trial_space ()](const matrix_t & locs) -> decltype (auto ) {
115119 return internals::point_basis_eval (fe_space, locs);
116120 };
117- // eval reference basis at quadrature nodes, store de_quadrature weights
121+
122+ // geometry
123+ const Triangulation& triangulation = fe_space.triangulation ();
124+ // eval reference basis at quadrature nodes, store de_quadrature weights
118125 de_quadrature_t <embed_dim> quad_rule;
119126 int n_quad_nodes = quad_rule.order ;
120127 int n_shape_functions = fe_space.n_shape_functions ();
@@ -144,15 +151,14 @@ struct fe_de_elliptic {
144151 it->measure ();
145152 }
146153 return grad;
147- };
154+ };
148155 return ;
149156 }
150157 // fit from geoframe
151158 template <typename GeoFrame> void analyze_data (const GeoFrame& gf) {
152159 fdapde_static_assert (GeoFrame::Order == 1 , THIS_CLASS_IS_FOR_ORDER_ONE_GEOFRAMES_ONLY);
153160 fdapde_assert (gf.n_layers () == 1 && gf[0 ].category ()[0 ] == ltype::point);
154161 n_obs_ = gf[0 ].rows ();
155-
156162 // eval physical basis at spatial locations
157163 const auto & spatial_index = geo_index_cast<0 , POINT>(gf[0 ]);
158164 if (spatial_index.points_at_dofs ()) {
@@ -164,15 +170,16 @@ struct fe_de_elliptic {
164170 return ;
165171 }
166172 // main fit entry point
167- template <typename Optimizer> const vector_t & fit (double lambda, const vector_t & g_init, Optimizer&& opt) {
168- g_ = opt.optimize (llik_t (*this , lambda, tol_), g_init);
173+ template <typename Optimizer, typename ... Callbacks>
174+ const vector_t & fit (double lambda, const vector_t & g_init, Optimizer&& opt, Callbacks&&... callbacks) {
175+ g_ = opt.optimize (llik_t (*this , lambda, tol_), g_init, std::forward<Callbacks>(callbacks)...);
169176 return g_;
170177 }
171- template <typename Optimizer, typename LambdaT>
178+ template <typename Optimizer, typename LambdaT, typename ... Callbacks >
172179 requires (internals::is_vector_like_v<LambdaT>)
173- const vector_t & fit (LambdaT&& lambda, const vector_t & g_init, Optimizer&& opt) {
180+ const vector_t & fit (LambdaT&& lambda, const vector_t & g_init, Optimizer&& opt, Callbacks&&... callbacks ) {
174181 fdapde_assert (lambda.size () == n_lambda);
175- return fit (lambda[0 ]);
182+ return fit (lambda[0 ], g_init, opt, std::forward<Callbacks>(callbacks)... );
176183 }
177184 // modifiers
178185 void set_llik_tolerance (double tol) { tol_ = tol; }
@@ -207,21 +214,31 @@ struct fe_de_elliptic {
207214
208215} // namespace internals
209216
210- // elliptic solver factory
211- template <typename BilinearForm , typename LinearForm > struct fe_de_elliptic {
217+ // elliptic solver API
218+ template <typename BilinearForm_ , typename LinearForm_ > struct fe_de_elliptic {
212219 using solver_t = internals::fe_de_elliptic;
213220 private:
214- struct info_t {
215- std::tuple<BilinearForm, LinearForm> penalty;
221+ struct penalty_packet {
222+ using BilinearForm = std::decay_t <BilinearForm_>;
223+ using LinearForm = std::decay_t <LinearForm_>;
224+ private:
225+ BilinearForm bilinear_form_;
226+ LinearForm linear_form_;
227+ public:
228+ penalty_packet (const BilinearForm_& bilinear_form, const LinearForm_& linear_form) :
229+ bilinear_form_ (bilinear_form), linear_form_(linear_form) { }
230+ // observers
231+ const BilinearForm& bilinear_form () const { return bilinear_form_; }
232+ const LinearForm& linear_form () const { return linear_form_; }
216233 };
217234 public:
218- fe_de_elliptic (const BilinearForm & bilinear_form, const LinearForm & linear_form) :
219- info_ (std::make_tuple( bilinear_form, linear_form) ) { }
220- const info_t & get () const { return info_ ; }
235+ fe_de_elliptic (const BilinearForm_ & bilinear_form, const LinearForm_ & linear_form) :
236+ penalty_ ( bilinear_form, linear_form) { }
237+ const penalty_packet & get () const { return penalty_ ; }
221238 private:
222- info_t info_ ;
223- };
224-
239+ penalty_packet penalty_ ;
240+ };
241+
225242} // namespace fdapde
226243
227244#endif // __FE_DE_ELLIPTIC_SOLVER_H__
0 commit comments