Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions source/module_hsolver/diago_bpcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ void DiagoBPCG<T, Device>::init_iter(const int nband, const int nbasis) {
}

template<typename T, typename Device>
bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, Real thr_in)
bool DiagoBPCG<T, Device>::test_error(const ct::Tensor& err_in, const std::vector<double>& ethr_band)
{
const Real * _err_st = err_in.data<Real>();
if (err_in.device_type() == ct::DeviceType::GpuDevice) {
ct::Tensor h_err_in = err_in.to_device<ct::DEVICE_CPU>();
_err_st = h_err_in.data<Real>();
}
for (int ii = 0; ii < this->n_band; ii++) {
if (_err_st[ii] > thr_in) {
if (_err_st[ii] > ethr_band[ii]) {
return true;
}
}
Expand Down Expand Up @@ -242,11 +242,11 @@ void DiagoBPCG<T, Device>::calc_hsub_with_block_exit(
return;
}

template<typename T, typename Device>
void DiagoBPCG<T, Device>::diag(
const HPsiFunc& hpsi_func,
T *psi_in,
Real* eigenvalue_in)
template <typename T, typename Device>
void DiagoBPCG<T, Device>::diag(const HPsiFunc& hpsi_func,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band)
{
const int current_scf_iter = hsolver::DiagoIterAssist<T, Device>::SCF_ITER;
// Get the pointer of the input psi
Expand Down Expand Up @@ -301,7 +301,7 @@ void DiagoBPCG<T, Device>::diag(
if (current_scf_iter == 1 && ntry % this->nline == 0) {
this->calc_hsub_with_block(hpsi_func, psi_in, this->psi, this->hpsi, this->hsub, this->work, this->eigen);
}
} while (ntry < max_iter && this->test_error(this->err_st, this->all_band_cg_thr));
} while (ntry < max_iter && this->test_error(this->err_st, ethr_band));

this->calc_hsub_with_block_exit(this->psi, this->hpsi, this->hsub, this->work, this->eigen);

Expand Down
10 changes: 5 additions & 5 deletions source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ class DiagoBPCG
* @param psi_in Pointer to input wavefunction psi matrix with [dim: n_basis x n_band, column major].
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
*/
void diag(const HPsiFunc& hpsi_func, T *psi_in, Real *eigenvalue_in);

void diag(const HPsiFunc& hpsi_func,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band);

private:
/// the number of rows of the input psi
Expand All @@ -77,8 +79,6 @@ class DiagoBPCG
int n_basis = 0;
/// max iter steps for all-band cg loop
int nline = 4;
/// cg convergence thr
Real all_band_cg_thr = 1E-5;

ct::DataType r_type = ct::DataType::DT_INVALID;
ct::DataType t_type = ct::DataType::DT_INVALID;
Expand Down Expand Up @@ -316,7 +316,7 @@ class DiagoBPCG
* @param thr_in The threshold.
* @return Returns true if all error values are less than or equal to the threshold, false otherwise.
*/
bool test_error(const ct::Tensor& err_in, Real thr_in);
bool test_error(const ct::Tensor& err_in, const std::vector<double>& ethr_band);

using ct_Device = typename ct::PsiToContainer<Device>::type;
using setmem_var_op = ct::kernels::set_memory<Real, ct_Device>;
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
};
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(nband, nbasis);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
}
else if (this->method == "dav_subspace")
{
Expand Down
9 changes: 5 additions & 4 deletions source/module_hsolver/test/diago_bpcg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,11 @@ class DiagoBPCGPrepare
hpsi_out, ld_psi);
};
bpcg.init_iter(nband, npw);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en);
std::vector<double> ethr_band(nband, 1e-5);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
bpcg.diag(hpsi_func, psi_local.get_pointer(), en, ethr_band);
end = MPI_Wtime();
//if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
delete [] DIAGOTEST::npw_local;
Expand Down
Loading