diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index aaa21877bc..635e3a7943 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -55,7 +55,7 @@ void DiagoBPCG::init_iter(const int nband, const int nbasis) { } template -bool DiagoBPCG::test_error(const ct::Tensor& err_in, Real thr_in) +bool DiagoBPCG::test_error(const ct::Tensor& err_in, const std::vector& ethr_band) { const Real * _err_st = err_in.data(); if (err_in.device_type() == ct::DeviceType::GpuDevice) { @@ -63,7 +63,7 @@ bool DiagoBPCG::test_error(const ct::Tensor& err_in, Real thr_in) _err_st = h_err_in.data(); } for (int ii = 0; ii < this->n_band; ii++) { - if (_err_st[ii] > thr_in) { + if (_err_st[ii] > ethr_band[ii]) { return true; } } @@ -242,11 +242,11 @@ void DiagoBPCG::calc_hsub_with_block_exit( return; } -template -void DiagoBPCG::diag( - const HPsiFunc& hpsi_func, - T *psi_in, - Real* eigenvalue_in) +template +void DiagoBPCG::diag(const HPsiFunc& hpsi_func, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band) { const int current_scf_iter = hsolver::DiagoIterAssist::SCF_ITER; // Get the pointer of the input psi @@ -301,7 +301,7 @@ void DiagoBPCG::diag( if (current_scf_iter == 1 && ntry % this->nline == 0) { this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen); } - } while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr)); + } while (ntry < max_iter && this->test_error(this->err_st, ethr_band)); this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen); diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index 2ca8167f9e..c57ed5e5ee 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -67,8 +67,10 @@ class DiagoBPCG * @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major]. * @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major]. */ - void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in); - + void diag(const HPsiFunc& hpsi_func, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band); private: /// the number of rows of the input psi @@ -77,8 +79,6 @@ class DiagoBPCG int n_basis = 0; /// max iter steps for all-band cg loop int nline = 4; - /// cg convergence thr - Real all_band_cg_thr = 1E-5; ct::DataType r_type = ct::DataType::DT_INVALID; ct::DataType t_type = ct::DataType::DT_INVALID; @@ -316,7 +316,7 @@ class DiagoBPCG * @param thr_in The threshold. * @return Returns true if all error values are less than or equal to the threshold, false otherwise. */ - bool test_error(const ct::Tensor& err_in, Real thr_in); + bool test_error(const ct::Tensor& err_in, const std::vector& ethr_band); using ct_Device = typename ct::PsiToContainer::type; using setmem_var_op = ct::kernels::set_memory; diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 44f5571116..01a50e84e8 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -493,7 +493,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, }; DiagoBPCG bpcg(pre_condition.data()); bpcg.init_iter(nband, nbasis); - bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue); + bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } else if (this->method == "dav_subspace") { diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index d15b539c7d..0ca9ff2444 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -152,10 +152,11 @@ class DiagoBPCGPrepare hpsi_out, ld_psi); }; bpcg.init_iter(nband, npw); - bpcg.diag(hpsi_func, psi_local.get_pointer(), en); - bpcg.diag(hpsi_func, psi_local.get_pointer(), en); - bpcg.diag(hpsi_func, psi_local.get_pointer(), en); - bpcg.diag(hpsi_func, psi_local.get_pointer(), en); + std::vector ethr_band(nband, 1e-5); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); + bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band); end = MPI_Wtime(); //if(mypnum == 0) printf("diago time:%7.3f\n",end-start); delete [] DIAGOTEST::npw_local;