@@ -78,6 +78,20 @@ struct fe_ls_elliptic {
7878 }
7979 return ;
8080 }
81+ void enforce_lhs_dirichlet_bc_ (SparseBlockMatrix<double , 2 , 2 >& A) {
82+ if (dirichlet_dofs_.size () == 0 ) { return ; }
83+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) {
84+ // zero out row and column in correspondance of Dirichlet-type dofs
85+ A.row (dirichlet_dofs_[i]) *= 0 ;
86+ A.col (dirichlet_dofs_[i]) *= 0 ;
87+ A.row (n_dofs_ + dirichlet_dofs_[i]) *= 0 ;
88+ A.col (n_dofs_ + dirichlet_dofs_[i]) *= 0 ;
89+ // set diagonal elements to 1
90+ A.coeffRef (dirichlet_dofs_[i], dirichlet_dofs_[i]) = 1 ;
91+ A.coeffRef (n_dofs_ + dirichlet_dofs_[i], n_dofs_ + dirichlet_dofs_[i]) = 1 ;
92+ }
93+ return ;
94+ }
8195 public:
8296 static constexpr int n_lambda = 1 ;
8397 using solver_category = ls_solver;
@@ -138,6 +152,10 @@ struct fe_ls_elliptic {
138152 return internals::areal_basis_eval (fe_space, locs);
139153 };
140154 b_.resize (2 * n_dofs_, 1 );
155+ // store Dirichlet boundary condition
156+ auto & dof_handler = bilinear_form.trial_space ().dof_handler ();
157+ dirichlet_dofs_ = dof_handler.dirichlet_dofs ();
158+ dirichlet_vals_ = dof_handler.dirichlet_values ();
141159 return ;
142160 }
143161 // non-parametric fit
@@ -219,6 +237,8 @@ struct fe_ls_elliptic {
219237 }
220238 if (old_n_obs != n_obs_) { W_ *= (double )old_n_obs / n_obs_; }
221239 b_.block (0 , 0 , n_dofs_, 1 ) = -PsiNA ().transpose () * D_ * W_ * y;
240+ // enforce dirichlet bc, if any
241+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) { b_.row (dirichlet_dofs_[i]).setConstant (dirichlet_vals_[i]); }
222242 return ;
223243 }
224244 template <typename WeightMatrix> void update_weights (const WeightMatrix& W) {
@@ -236,8 +256,10 @@ struct fe_ls_elliptic {
236256 V_.block (0 , 0 , n_covs_, n_dofs_) = X_.transpose () * W_ * PsiNA ();
237257 b_.block (0 , 0 , n_dofs_, 1 ) = -PsiNA ().transpose () * D_ * internals::lmbQ (W_, X_, invXtWX_, y_);
238258 }
239- W_changed_ = true ;
240- return ;
259+ // enforce dirichlet bc, if any
260+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) { b_.row (dirichlet_dofs_[i]).setConstant (dirichlet_vals_[i]); }
261+ W_changed_ = true ;
262+ return ;
241263 }
242264 template <typename WeightMatrix> void update_response_and_weights (const vector_t & y, const WeightMatrix& W) {
243265 fdapde_assert (
@@ -261,12 +283,14 @@ struct fe_ls_elliptic {
261283 // assemble and factorize system matrix for nonparameteric part
262284 SparseBlockMatrix<double , 2 , 2 > A (
263285 -PsiNA ().transpose () * D_ * W_ * PsiNA (), lambda * R1_.transpose (), lambda * R1_, lambda * R0_);
286+ enforce_lhs_dirichlet_bc_ (A);
264287 invA_.compute (A);
265288 W_changed_ = false ;
266289 }
267290 if (lambda_saved_.value () != lambda) {
268291 // update linear system rhs
269292 b_.block (n_dofs_, 0 , n_dofs_, 1 ) = lambda * u_;
293+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) { b_.row (n_dofs_ + dirichlet_dofs_[i]).setZero (); }
270294 }
271295 lambda_saved_ = lambda;
272296 vector_t x;
@@ -294,17 +318,26 @@ struct fe_ls_elliptic {
294318 // assemble and factorize system matrix for nonparameteric part
295319 SparseBlockMatrix<double , 2 , 2 > A (
296320 -PsiNA ().transpose () * D_ * W_ * PsiNA (), lambda * R1_.transpose (), lambda * R1_, lambda * R0_);
321+ enforce_lhs_dirichlet_bc_ (A);
297322 invA_.compute (A);
298323 }
299324 vector_t x;
300325 if (n_covs_ == 0 ) { // equivalent to calling fit(lambda)
301- if (lambda_saved_.value () != lambda) { b_.block (n_dofs_, 0 , n_dofs_, 1 ) = lambda * u_; }
326+ if (lambda_saved_.value () != lambda) {
327+ b_.block (n_dofs_, 0 , n_dofs_, 1 ) = lambda * u_;
328+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) { b_.row (n_dofs_ + dirichlet_dofs_[i]).setZero (); }
329+ }
302330 x = invA_.solve (b_);
303331 } else {
304332 vector_t b (2 * n_dofs_);
305333 // assemble nonparametric linear system rhs
306334 b.block (0 , 0 , n_dofs_, 1 ) = -PsiNA ().transpose () * D_ * W_ * y_;
307- b.block (n_dofs_, 0 , n_dofs_, 1 ) = lambda * u_;
335+ b.block (n_dofs_, 0 , n_dofs_, 1 ) = lambda * u_;
336+ // enforce Dirichlet BCs, if any
337+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) {
338+ b_.row (dirichlet_dofs_[i]).setConstant (dirichlet_vals_[i]);
339+ b_.row (n_dofs_ + dirichlet_dofs_[i]).setZero ();
340+ }
308341 x = invA_.solve (b);
309342 }
310343 lambda_saved_ = lambda;
@@ -332,6 +365,10 @@ struct fe_ls_elliptic {
332365 } else {
333366 Bs_->topRows (n_dofs_) = -PsiNA ().transpose () * D_ * internals::lmbQ (W_, X_, invXtWX_, *Us_);
334367 }
368+ // enforce Dirichlet BCs, if any
369+ for (int i = 0 ; i < dirichlet_dofs_.size (); ++i) {
370+ Bs_->row (dirichlet_dofs_[i]).setConstant (dirichlet_vals_[i]);
371+ }
335372 matrix_t x = n_covs_ == 0 ? invA_.solve (*Bs_) : woodbury_system_solve (invA_, U_, XtWX_, V_, *Bs_);
336373 double trS = 0 ; // monte carlo Tr[S] approximation
337374 for (int i = 0 ; i < r; ++i) { trS += Ys_->row (i).dot (x.col (i).head (n_dofs_)); }
@@ -349,9 +386,10 @@ struct fe_ls_elliptic {
349386 lambda_ = lambda;
350387 }
351388 if (lambda_saved_.value () != lambda_) {
352- SparseBlockMatrix<double , 2 , 2 > A_ (
389+ SparseBlockMatrix<double , 2 , 2 > A (
353390 -PsiNA ().transpose () * D_ * W_ * PsiNA (), lambda_ * R1_.transpose (), lambda_ * R1_, lambda_ * R0_);
354- invA_.compute (A_);
391+ enforce_lhs_dirichlet_bc_ (A);
392+ invA_.compute (A);
355393 lambda_saved_ = lambda_;
356394 }
357395 return edf (r, seed);
@@ -408,11 +446,9 @@ struct fe_ls_elliptic {
408446 std::optional<double > lambda_saved_ = -1 ;
409447 sparse_solver_t invA_;
410448 matrix_t b_;
411- // matrices for hutchinson stochastic estimation of Tr[S]
412- std::optional<matrix_t > Ys_;
413- std::optional<matrix_t > Bs_;
414- std::optional<matrix_t > Us_;
415-
449+ // matrices for Hutchinson stochastic estimation of Tr[S]
450+ std::optional<matrix_t > Ys_, Bs_, Us_;
451+
416452 int n_dofs_ = 0 , n_locs_ = 0 , n_obs_ = 0 , n_covs_ = 0 ;
417453 sparse_matrix_t R0_; // n_dofs x n_dofs matrix [R0]_{ij} = \int_D \psi_i * \psi_j
418454 sparse_matrix_t R1_; // n_dofs x n_dofs matrix [R1]_{ij} = \int_D a(\psi_i, \psi_j)
@@ -425,6 +461,8 @@ struct fe_ls_elliptic {
425461 // basis system evaluation handles
426462 std::function<sparse_matrix_t (const matrix_t & locs)> point_eval_;
427463 std::function<std::pair<sparse_matrix_t , vector_t >(const binary_t & locs)> areal_eval_;
464+ std::vector<int > dirichlet_dofs_; // dofs where Dirichlet boundary conditions are imposed
465+ std::vector<double > dirichlet_vals_; // values imposed at Dirichlet dofs
428466
429467 matrix_t X_; // n_obs x n_covs design matrix
430468 vector_t y_; // n_obs x 1 observation vector
0 commit comments