@@ -54,7 +54,10 @@ DiagoCG<T, Device>::~DiagoCG()
5454}
5555
5656template <typename T, typename Device>
57- void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, ct::Tensor& eigen)
57+ void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in,
58+ ct::Tensor& psi,
59+ ct::Tensor& eigen,
60+ const std::vector<double >& ethr_band)
5861{
5962 ModuleBase::TITLE (" DiagoCG" , " diag_once" );
6063 ModuleBase::timer::tick (" DiagoCG" , " diag_once" );
@@ -153,6 +156,7 @@ void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, c
153156 converged = this ->update_psi (pphi,
154157 cg,
155158 scg, // const Tensor&
159+ ethr_band[m],
156160 cg_norm,
157161 theta,
158162 eigen_pack[m], // Real&
@@ -392,6 +396,7 @@ template <typename T, typename Device>
392396bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
393397 const ct::Tensor& cg,
394398 const ct::Tensor& scg,
399+ const double & ethreshold,
395400 Real& cg_norm,
396401 Real& theta,
397402 Real& eigen,
@@ -441,7 +446,7 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
441446 cg.data <T>(),
442447 sint_norm);
443448
444- if (std::abs (eigen - e0 ) < pw_diag_thr_ )
449+ if (std::abs (eigen - e0 ) < ethreshold )
445450 {
446451 // ModuleBase::timer::tick("DiagoCG","update");
447452 return true ;
@@ -582,6 +587,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
582587 const Func& spsi_func,
583588 ct::Tensor& psi,
584589 ct::Tensor& eigen,
590+ const std::vector<double >& ethr_band,
585591 const ct::Tensor& prec)
586592{
587593 // / record the times of trying iterative diagonalization
@@ -603,7 +609,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
603609
604610 ++ntry;
605611 avg_iter_ += 1.0 ;
606- this ->diag_mock (prec, psi_temp, eigen);
612+ this ->diag_mock (prec, psi_temp, eigen, ethr_band );
607613 } while (this ->test_exit_cond (ntry, this ->notconv_ ));
608614
609615 if (this ->notconv_ > std::max (5 , this ->n_band_ / 4 ))
0 commit comments