@@ -55,15 +55,15 @@ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
5555}
5656
5757template <typename T, typename Device>
58- bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, Real thr_in )
58+ bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, const std::vector< double >& ethr_band )
5959{
6060 const Real * _err_st = err_in.data <Real>();
6161 if (err_in.device_type () == ct::DeviceType::GpuDevice) {
6262 ct::Tensor h_err_in = err_in.to_device <ct::DEVICE_CPU>();
6363 _err_st = h_err_in.data <Real>();
6464 }
6565 for (int ii = 0 ; ii < this ->n_band ; ii++) {
66- if (_err_st[ii] > thr_in ) {
66+ if (_err_st[ii] > ethr_band[ii] ) {
6767 return true ;
6868 }
6969 }
@@ -242,11 +242,11 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
242242 return ;
243243}
244244
245- template <typename T, typename Device>
246- void DiagoBPCG<T, Device>::diag(
247- const HPsiFunc& hpsi_func ,
248- T *psi_in ,
249- Real* eigenvalue_in )
245+ template <typename T, typename Device>
246+ void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
247+ T* psi_in ,
248+ Real* eigenvalue_in ,
249+ const std::vector< double >& ethr_band )
250250{
251251 const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
252252 // Get the pointer of the input psi
@@ -301,7 +301,7 @@ void DiagoBPCG<T, Device>::diag(
301301 if (current_scf_iter == 1 && ntry % this ->nline == 0 ) {
302302 this ->calc_hsub_with_block (hpsi_func, psi_in, this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
303303 }
304- } while (ntry < max_iter && this ->test_error (this ->err_st , this -> all_band_cg_thr ));
304+ } while (ntry < max_iter && this ->test_error (this ->err_st , ethr_band ));
305305
306306 this ->calc_hsub_with_block_exit (this ->psi , this ->hpsi , this ->hsub , this ->work , this ->eigen );
307307
0 commit comments