Skip to content

Commit 8b2ec37

Browse files
committed
add smooth threshold for cg
1 parent f8ef50f commit 8b2ec37

File tree

7 files changed

+36
-15
lines changed

7 files changed

+36
-15
lines changed

source/module_hsolver/diago_cg.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ DiagoCG<T, Device>::~DiagoCG()
5454
}
5555

5656
template <typename T, typename Device>
57-
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, ct::Tensor& eigen)
57+
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in,
58+
ct::Tensor& psi,
59+
ct::Tensor& eigen,
60+
const std::vector<double>& ethr_band)
5861
{
5962
ModuleBase::TITLE("DiagoCG", "diag_once");
6063
ModuleBase::timer::tick("DiagoCG", "diag_once");
@@ -153,6 +156,7 @@ void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, c
153156
converged = this->update_psi(pphi,
154157
cg,
155158
scg, // const Tensor&
159+
ethr_band[m],
156160
cg_norm,
157161
theta,
158162
eigen_pack[m], // Real&
@@ -392,6 +396,7 @@ template <typename T, typename Device>
392396
bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
393397
const ct::Tensor& cg,
394398
const ct::Tensor& scg,
399+
const double& ethreshold,
395400
Real& cg_norm,
396401
Real& theta,
397402
Real& eigen,
@@ -441,7 +446,7 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
441446
cg.data<T>(),
442447
sint_norm);
443448

444-
if (std::abs(eigen - e0) < pw_diag_thr_)
449+
if (std::abs(eigen - e0) < ethreshold)
445450
{
446451
// ModuleBase::timer::tick("DiagoCG","update");
447452
return true;
@@ -582,6 +587,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
582587
const Func& spsi_func,
583588
ct::Tensor& psi,
584589
ct::Tensor& eigen,
590+
const std::vector<double>& ethr_band,
585591
const ct::Tensor& prec)
586592
{
587593
/// record the times of trying iterative diagonalization
@@ -603,7 +609,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
603609

604610
++ntry;
605611
avg_iter_ += 1.0;
606-
this->diag_mock(prec, psi_temp, eigen);
612+
this->diag_mock(prec, psi_temp, eigen, ethr_band);
607613
} while (this->test_exit_cond(ntry, this->notconv_));
608614

609615
if (this->notconv_ > std::max(5, this->n_band_ / 4))

source/module_hsolver/diago_cg.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ class DiagoCG final
4040
// virtual void init(){};
4141
// refactor hpsi_info
4242
// this is the diag() function for CG method
43-
void diag(const Func& hpsi_func, const Func& spsi_func, ct::Tensor& psi, ct::Tensor& eigen, const ct::Tensor& prec = {});
43+
void diag(const Func& hpsi_func,
44+
const Func& spsi_func,
45+
ct::Tensor& psi,
46+
ct::Tensor& eigen,
47+
const std::vector<double>& ethr_band,
48+
const ct::Tensor& prec = {});
4449

4550
private:
4651
Device * ctx_ = {};
@@ -103,6 +108,7 @@ class DiagoCG final
103108
const ct::Tensor& pphi,
104109
const ct::Tensor& cg,
105110
const ct::Tensor& scg,
111+
const double& ethreshold,
106112
Real &cg_norm,
107113
Real &theta,
108114
Real &eigen,
@@ -113,7 +119,10 @@ class DiagoCG final
113119
void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);
114120

115121
// used in diag() for template replace Hamilt with Hamilt_PW
116-
void diag_mock(const ct::Tensor& prec, ct::Tensor& psi, ct::Tensor& eigen);
122+
void diag_mock(const ct::Tensor& prec,
123+
ct::Tensor& psi,
124+
ct::Tensor& eigen,
125+
const std::vector<double>& ethr_band);
117126

118127
bool test_exit_cond(const int& ntry, const int& notconv) const;
119128

source/module_hsolver/hsolver_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
467467
.to_device<ct_Device>()
468468
.slice({0}, {psi.get_current_nbas()});
469469

470-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
470+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor);
471471
// TODO: Double check tensormap's potential problem
472472
ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
473473
}

source/module_hsolver/test/diago_cg_float_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ class DiagoCGPrepare
193193
ct::DataType::DT_FLOAT,
194194
ct::DeviceType::CpuDevice,
195195
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
196-
197-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
196+
197+
std::vector<double> ethr_band(nband, 1e-5);
198+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
198199
// TODO: Double check tensormap's potential problem
199200
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
200201
/**************************************************************/

source/module_hsolver/test/diago_cg_real_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,9 @@ class DiagoCGPrepare
196196
ct::DataType::DT_DOUBLE,
197197
ct::DeviceType::CpuDevice,
198198
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
199-
200-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
199+
200+
std::vector<double> ethr_band(nband, 1e-5);
201+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
201202
// TODO: Double check tensormap's potential problem
202203
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
203204
/**************************************************************/

source/module_hsolver/test/diago_cg_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,9 @@ class DiagoCGPrepare
187187
ct::DataType::DT_DOUBLE,
188188
ct::DeviceType::CpuDevice,
189189
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
190-
191-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
190+
191+
std::vector<double> ethr_band(nband, 1e-5);
192+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
192193
// TODO: Double check tensormap's potential problem
193194
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
194195
/**************************************************************/

source/module_lr/hsolver_lrtd.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,12 @@ namespace LR
146146
std::vector<Real<T>> precondition_(precondition); //since TensorMap does not support const pointer
147147
auto precon_tensor = ct::TensorMap(precondition_.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
148148
auto hpsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& hpsi) {hm.hPsi(psi_in.data<T>(), hpsi.data<T>(), psi_in.shape().dim_size(0) /*nbasis_local*/, 1/*band-by-band*/);};
149-
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi)
150-
{ std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements()); };
151-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, precon_tensor);
149+
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi) {
150+
std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements());
151+
};
152+
153+
std::vector<double> ethr_band(nband, diag_ethr);
154+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, precon_tensor);
152155
}
153156
else { throw std::runtime_error("HSolverLR::solve: method not implemented"); }
154157
}

0 commit comments

Comments
 (0)