Skip to content

Commit e17619d

Browse files
haozhihanFisherd99
authored andcommitted
Refactor: add smooth threshold support for david method (deepmodeling#5697)
* change raw pointer to std::vector * add ethr_band for dav method * change unit test * fix build bug * fix pyabacus * fix pyabacus dav-subspace * fix pyabacus build * fix pyabacus * add & for vector
1 parent 701b3b2 commit e17619d

File tree

14 files changed

+55
-33
lines changed

14 files changed

+55
-33
lines changed

python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ class PyDiagoDavSubspace
101101

102102
int diag(
103103
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
104-
std::vector<double> precond_vec,
104+
std::vector<double>& precond_vec,
105105
int dav_ndim,
106106
double tol,
107107
int max_iter,
108108
bool need_subspace,
109-
std::vector<double> diag_ethr,
109+
std::vector<double>& diag_ethr,
110110
bool scf_type,
111111
hsolver::diag_comm_info comm_info
112112
) {
@@ -141,7 +141,7 @@ class PyDiagoDavSubspace
141141
comm_info
142142
);
143143

144-
return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr.data(), scf_type);
144+
return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr, scf_type);
145145
}
146146

147147
private:

python/pyabacus/src/hsolver/py_diago_david.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ class PyDiagoDavid
101101

102102
int diag(
103103
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
104-
std::vector<double> precond_vec,
104+
std::vector<double>& precond_vec,
105105
int dav_ndim,
106106
double tol,
107+
std::vector<double>& diag_ethr,
107108
int max_iter,
108109
bool use_paw,
109110
hsolver::diag_comm_info comm_info
@@ -146,7 +147,7 @@ class PyDiagoDavid
146147
comm_info
147148
);
148149

149-
return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, tol, max_iter);
150+
return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, diag_ethr, max_iter);
150151
}
151152

152153
private:

python/pyabacus/src/hsolver/py_hsolver.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ void bind_hsolver(py::module& m)
121121
eigenvectors to be calculated.
122122
tol : double
123123
The tolerance for the convergence.
124+
diag_ethr: np.ndarray
125+
The tolerance vector.
124126
max_iter : int
125127
The maximum number of iterations.
126128
use_paw : bool
@@ -130,6 +132,7 @@ void bind_hsolver(py::module& m)
130132
"precond_vec"_a,
131133
"dav_ndim"_a,
132134
"tol"_a,
135+
"diag_ethr"_a,
133136
"max_iter"_a,
134137
"use_paw"_a,
135138
"comm_info"_a)

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def davidson(
118118
dav_ndim: int = 2,
119119
tol: float = 1e-2,
120120
max_iter: int = 1000,
121+
diag_ethr: Union[List[float], None] = None,
121122
use_paw: bool = False,
122123
# scf_type: bool = False
123124
) -> Tuple[NDArray[np.float64], NDArray[np.complex128]]:
@@ -143,6 +144,8 @@ def davidson(
143144
The tolerance for the convergence, by default 1e-2.
144145
max_iter : int, optional
145146
The maximum number of iterations, by default 1000.
147+
diag_ethr : List[float] | None, optional
148+
The list of thresholds of bands, by default None.
146149
use_paw : bool, optional
147150
Whether to use projector augmented wave (PAW) method, by default False.
148151
@@ -164,12 +167,16 @@ def davidson(
164167
_diago_obj_david.init_eigenvalue()
165168

166169
comm_info = diag_comm_info(0, 1)
170+
171+
if diag_ethr is None:
172+
diag_ethr = [tol] * num_eigs
167173

168174
_ = _diago_obj_david.diag(
169175
mvv_op,
170176
precondition,
171177
dav_ndim,
172178
tol,
179+
diag_ethr,
173180
max_iter,
174181
use_paw,
175182
comm_info

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
8888
T* psi_in,
8989
const int psi_in_dmax,
9090
Real* eigenvalue_in_hsolver,
91-
const double* ethr_band)
91+
const std::vector<double>& ethr_band)
9292
{
9393
ModuleBase::timer::tick("Diago_DavSubspace", "diag_once");
9494

@@ -726,7 +726,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
726726
T* psi_in,
727727
const int psi_in_dmax,
728728
Real* eigenvalue_in_hsolver,
729-
const double* ethr_band,
729+
const std::vector<double>& ethr_band,
730730
const bool& scf_type)
731731
{
732732
/// record the times of trying iterative diagonalization

source/module_hsolver/diago_dav_subspace.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Diago_DavSubspace
4242
T* psi_in,
4343
const int psi_in_dmax,
4444
Real* eigenvalue_in,
45-
const double* ethr_band,
45+
const std::vector<double>& ethr_band,
4646
const bool& scf_type);
4747

4848
private:
@@ -135,7 +135,7 @@ class Diago_DavSubspace
135135
T* psi_in,
136136
const int psi_in_dmax,
137137
Real* eigenvalue_in,
138-
const double* ethr_band);
138+
const std::vector<double>& ethr_band);
139139

140140
bool test_exit_cond(const int& ntry, const int& notconv, const bool& scf);
141141

source/module_hsolver/diago_david.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
156156
const int ld_psi,
157157
T *psi_in,
158158
Real* eigenvalue_in,
159-
const Real david_diag_thr,
159+
const std::vector<double>& ethr_band,
160160
const int david_maxiter)
161161
{
162162
if (test_david == 1)
@@ -273,7 +273,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
273273
this->notconv = 0;
274274
for (int m = 0; m < nband; m++)
275275
{
276-
convflag[m] = (std::abs(this->eigenvalue[m] - eigenvalue_in[m]) < david_diag_thr);
276+
convflag[m] = (std::abs(this->eigenvalue[m] - eigenvalue_in[m]) < ethr_band[m]);
277277
if (!convflag[m])
278278
{
279279
unconv[this->notconv] = m;
@@ -1177,7 +1177,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
11771177
const int ld_psi,
11781178
T *psi_in,
11791179
Real* eigenvalue_in,
1180-
const Real david_diag_thr,
1180+
const std::vector<double>& ethr_band,
11811181
const int david_maxiter,
11821182
const int ntry_max,
11831183
const int notconv_max)
@@ -1189,7 +1189,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
11891189
int sum_dav_iter = 0;
11901190
do
11911191
{
1192-
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ld_psi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
1192+
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ld_psi, psi_in, eigenvalue_in, ethr_band, david_maxiter);
11931193
++ntry;
11941194
} while (!check_block_conv(ntry, this->notconv, ntry_max, notconv_max));
11951195

source/module_hsolver/diago_david.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class DiagoDavid
7979
const int ld_psi, // Leading dimension of the psi input
8080
T *psi_in, // Pointer to eigenvectors
8181
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
82-
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
82+
const std::vector<double>& ethr_band, // Convergence threshold for the Davidson iteration
8383
const int david_maxiter, // Maximum allowed iterations for the Davidson method
8484
const int ntry_max = 5, // Maximum number of diagonalization attempts (5 by default)
8585
const int notconv_max = 0); // Maximum number of allowed non-converged eigenvectors
@@ -134,7 +134,7 @@ class DiagoDavid
134134
const int ld_psi,
135135
T *psi_in,
136136
Real* eigenvalue_in,
137-
const Real david_diag_thr,
137+
const std::vector<double>& ethr_band,
138138
const int david_maxiter);
139139

140140
void cal_grad(const HPsiFunc& hpsi_func,

source/module_hsolver/hsolver_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
526526
comm_info);
527527

528528
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
529-
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band.data(), scf));
529+
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
530530
}
531531
else if (this->method == "dav")
532532
{
@@ -589,7 +589,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
589589
ld_psi,
590590
psi.get_pointer(),
591591
eigenvalue,
592-
david_diag_thr,
592+
this->ethr_band,
593593
david_maxiter,
594594
ntry_max,
595595
notconv_max));

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,12 @@ class DiagoDavPrepare
119119
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
120120
phm->ops->hPsi(info);
121121
};
122-
auto spsi_func = [phm](const std::complex<float>* psi_in, std::complex<float>* spsi_out,const int ld_psi, const int nbands){
123-
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
124-
};
125-
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
122+
auto spsi_func = [phm](const std::complex<float>* psi_in,
123+
std::complex<float>* spsi_out,
124+
const int ld_psi,
125+
const int nbands) { phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands); };
126+
std::vector<double> ethr_band(phi.get_nbands(), eps);
127+
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, ethr_band, maxiter);
126128

127129
#ifdef __MPI
128130
end = MPI_Wtime();

0 commit comments

Comments
 (0)