@@ -32,29 +32,27 @@ template <typename VariationalSolver> class QSRPDE {
3232 template <typename GeoFrame, typename Penalty>
3333 QSRPDE (const std::string& formula, const GeoFrame& gf, double alpha, Penalty&& penalty) noexcept :
3434 solver_ (), alpha_(alpha) {
35+ discretize (penalty.get ().penalty );
36+ analyze_data (formula, gf);
37+ }
38+ template <typename GeoFrame, typename Penalty>
39+ QSRPDE (const std::string& formula, const GeoFrame& gf, Penalty&& penalty) noexcept :
40+ QSRPDE (formula, gf, 0.5 , penalty) { }
41+
42+ // modifiers
43+ void set_level (double alpha) { alpha_ = alpha; }
44+ template <typename ... Args> void discretize (Args&&... args) { solver_.discretize (std::forward<Args>(args)...); }
45+ template <typename GeoFrame, typename WeightMatrix>
46+ void analyze_data (const std::string& formula, const GeoFrame& gf, const WeightMatrix& W) {
3547 fdapde_assert (gf.n_layers () == 1 );
3648 Formula formula_ (formula);
3749 n_obs_ = gf[0 ].rows ();
3850 n_covs_ = 0 ;
3951 for (const std::string& token : formula_.rhs ()) {
4052 if (gf.contains (token)) { n_covs_++; }
4153 }
42- // discretize
43- if constexpr (requires (Penalty p) { p.get (); }) {
44- solver_ = solver_t (formula, gf, penalty.get ());
45- } else {
46- solver_ = solver_t (formula, gf, penalty (gf.template triangulation <0 >()).get ());
47- }
48- y_ = solver_.response ();
49- }
50-
51- // modifiers
52- template <typename ... Args> void discretize (Args&&... args) {
53- return solver_.discretize (std::forward<Args>(args)...);
54- }
55- template <typename GeoFrame, typename WeightMatrix>
56- void analyze_data (const std::string& formula, const GeoFrame& gf, const WeightMatrix& W) {
57- return solver_.analyze_data (formula, gf, W);
54+ solver_.analyze_data (formula, gf, W);
55+ y_ = solver_.response ();
5856 }
5957 template <typename GeoFrame> void analyze_data (const std::string& formula, const GeoFrame& gf) {
6058 return analyze_data (formula, gf, vector_t::Ones (gf[0 ].rows ()).asDiagonal ());
@@ -134,11 +132,21 @@ template <typename VariationalSolver> class QSRPDE {
134132 static constexpr int XprBits = 0 ;
135133 using Scalar = double ;
136134 using InputType = Vector<Scalar, StaticInputSize>;
135+ using edf_cache_t = std::unordered_map<
136+ std::array<double , StaticInputSize>, double , internals::std_array_hash<double , StaticInputSize>>;
137137
138138 gcv_t () noexcept = default ;
139- gcv_t (QSRPDE* model) : model_(model), n_(model->n_obs ()), q_(model->n_covs ()), r_(100 ), seed_(random_seed) { }
140- gcv_t (QSRPDE* model, int r, int seed) :
141- model_(model), n_(model->n_obs ()), q_(model->n_covs ()), r_(r), seed_(seed) { }
139+ gcv_t (QSRPDE* model, const edf_cache_t & edf_cache) :
140+ model_ (model),
141+ n_ (model->n_obs ()),
142+ q_(model->n_covs ()),
143+ edf_cache_(edf_cache),
144+ r_(100 ),
145+ seed_(random_seed) { }
146+ gcv_t (QSRPDE* model, const edf_cache_t & edf_cache, int r, int seed) :
147+ model_(model), n_(model->n_obs ()), q_(model->n_covs ()), edf_cache_(edf_cache), r_(r), seed_(seed) { }
148+ gcv_t (QSRPDE* model) : gcv_t(model, edf_cache_t ()) { }
149+ gcv_t (QSRPDE* model, int r, int seed) : gcv_t(model, edf_cache_t (), r, seed) { }
142150
143151 template <typename InputType_>
144152 requires (internals::is_subscriptable<InputType_, int >)
@@ -150,28 +158,31 @@ template <typename VariationalSolver> class QSRPDE {
150158 constexpr double operator()(LambdaT... lambda) {
151159 model_->fit (static_cast <double >(lambda)...);
152160 std::array<double , StaticInputSize> lambda_vec {lambda...};
153- if (edf_map_ .find (lambda_vec) == edf_map_ .end ()) { // cache Tr[S]
154- edf_map_ [lambda_vec] = model_->edf (r_, seed_);
161+ if (edf_cache_ .find (lambda_vec) == edf_cache_ .end ()) { // cache Tr[S]
162+ edf_cache_ [lambda_vec] = model_->edf (r_, seed_);
155163 }
156- double dor = n_ - (q_ + edf_map_ .at (lambda_vec)); // residual degrees of freedom
164+ double dor = n_ - (q_ + edf_cache_ .at (lambda_vec)); // residual degrees of freedom
157165 double pinball = 0 ;
158166 for (int i = 0 ; i < n_; ++i) {
159167 pinball += model_->pinball_loss (model_->y_ [i] - model_->mu_ [i], std::pow (10 , model_->eps_ ));
160168 }
161169 return (std::pow (pinball, 2 ) / std::pow (dor, 2 ));
162170 }
171+ // observers
172+ const edf_cache_t & edf_cache () const { return edf_cache_; }
173+ edf_cache_t & edf_cache () { return edf_cache_; }
163174 private:
164175 QSRPDE* model_;
165176 int n_ = 0 , q_ = 0 ;
166- std::unordered_map<
167- std::array<double , StaticInputSize>, double , internals::std_array_hash<double , StaticInputSize>>
168- edf_map_;
177+ edf_cache_t edf_cache_;
169178 // stochastic edf approximation parameter
170179 int r_, seed_;
171180 };
172181 friend gcv_t ;
173182 gcv_t gcv () { return gcv_t (this ); }
183+ gcv_t gcv (const typename gcv_t ::edf_cache_t & edf_cache) { return gcv_t (this , edf_cache); }
174184 gcv_t gcv (int r, int seed) { return gcv_t (this , r, seed); }
185+ gcv_t gcv (const typename gcv_t ::edf_cache_t & edf_cache, int r, int seed) { return gcv_t (this , edf_cache, r, seed); }
175186
176187 // inference
177188
0 commit comments