Skip to content

Commit ab2a789

Browse files
authored
Refactor: add smooth threshold support for bpcg method (#5709)
* add smooth threshold support for bpcg method * fix build test bug * fix build bug
1 parent c5d5eab commit ab2a789

File tree

4 files changed

+19
-18
lines changed

4 files changed

+19
-18
lines changed

source/module_hsolver/diago_bpcg.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,15 @@ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
5555
}
5656

5757
template<typename T, typename Device>
58-
bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, Real thr_in)
58+
bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, const std::vector<double>& ethr_band)
5959
{
6060
const Real * _err_st = err_in.data<Real>();
6161
if (err_in.device_type() == ct::DeviceType::GpuDevice) {
6262
ct::Tensor h_err_in = err_in.to_device<ct::DEVICE_CPU>();
6363
_err_st = h_err_in.data<Real>();
6464
}
6565
for (int ii = 0; ii < this->n_band; ii++) {
66-
if (_err_st[ii] > thr_in) {
66+
if (_err_st[ii] > ethr_band[ii]) {
6767
return true;
6868
}
6969
}
@@ -242,11 +242,11 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
242242
return;
243243
}
244244

245-
template<typename T, typename Device>
246-
void DiagoBPCG<T, Device>::diag(
247-
const HPsiFunc& hpsi_func,
248-
T *psi_in,
249-
Real* eigenvalue_in)
245+
template <typename T, typename Device>
246+
void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
247+
T* psi_in,
248+
Real* eigenvalue_in,
249+
const std::vector<double>& ethr_band)
250250
{
251251
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
252252
// Get the pointer of the input psi
@@ -301,7 +301,7 @@ void DiagoBPCG<T, Device>::diag(
301301
if (current_scf_iter == 1 && ntry % this->nline == 0) {
302302
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
303303
}
304-
} while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr));
304+
} while (ntry < max_iter && this->test_error(this->err_st, ethr_band));
305305

306306
this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen);
307307

source/module_hsolver/diago_bpcg.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ class DiagoBPCG
6767
* @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major].
6868
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
6969
*/
70-
void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in);
71-
70+
void diag(const HPsiFunc& hpsi_func,
71+
T* psi_in,
72+
Real* eigenvalue_in,
73+
const std::vector<double>& ethr_band);
7274

7375
private:
7476
/// the number of rows of the input psi
@@ -77,8 +79,6 @@ class DiagoBPCG
7779
int n_basis = 0;
7880
/// max iter steps for all-band cg loop
7981
int nline = 4;
80-
/// cg convergence thr
81-
Real all_band_cg_thr = 1E-5;
8282

8383
ct::DataType r_type = ct::DataType::DT_INVALID;
8484
ct::DataType t_type = ct::DataType::DT_INVALID;
@@ -316,7 +316,7 @@ class DiagoBPCG
316316
* @param thr_in The threshold.
317317
* @return Returns true if all error values are less than or equal to the threshold, false otherwise.
318318
*/
319-
bool test_error(const ct::Tensor& err_in, Real thr_in);
319+
bool test_error(const ct::Tensor& err_in, const std::vector<double>& ethr_band);
320320

321321
using ct_Device = typename ct::PsiToContainer<Device>::type;
322322
using setmem_var_op = ct::kernels::set_memory<Real, ct_Device>;

source/module_hsolver/hsolver_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
493493
};
494494
DiagoBPCG<T, Device> bpcg(pre_condition.data());
495495
bpcg.init_iter(nband, nbasis);
496-
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue);
496+
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
497497
}
498498
else if (this->method == "dav_subspace")
499499
{

source/module_hsolver/test/diago_bpcg_test.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,11 @@ class DiagoBPCGPrepare
152152
hpsi_out, ld_psi);
153153
};
154154
bpcg.init_iter(nband, npw);
155-
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
156-
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
157-
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
158-
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
155+
std::vector<double> ethr_band(nband, 1e-5);
156+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
157+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
158+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
159+
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
159160
end = MPI_Wtime();
160161
//if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
161162
delete [] DIAGOTEST::npw_local;

0 commit comments

Comments
 (0)