Skip to content

Commit 72b1d7c

Browse files
Refactor hpsi_func in hsolver (#5202)
* Refactor hpsi_func of dav_subspace * Modified the hpsi_func in pyabacus to maintain definition consistency * Fix hpsi_func in pyabacus-dav_subspace * [pre-commit.ci lite] apply automatic fixes * Refactor hpsi_func of dav * Change hpsi_func of hsolver_lrtd * Modify hpsi_func definition in dav tests * Modify the hpsi_func in pyabacus to maintain definition consistency * [pre-commit.ci lite] apply automatic fixes * Modify hsolver_pw_sup mock func signature * Update docs for new hpsi_func * Update docs * Change indent to make it prettier * Update docs * Update spsi docs * Update parameter name of spsi_func interface * Rename leading dimension vars from camel to snake case * Rename leading dimension in pyabacus to snake case --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 800987a commit 72b1d7c

File tree

12 files changed

+135
-137
lines changed

12 files changed

+135
-137
lines changed

python/pyabacus/src/py_diago_dav_subspace.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,21 @@ class PyDiagoDavSubspace
113113
auto hpsi_func = [mm_op] (
114114
std::complex<double> *psi_in,
115115
std::complex<double> *hpsi_out,
116-
const int nband_in,
117-
const int nbasis_in,
118-
const int band_index1,
119-
const int band_index2
116+
const int ld_psi,
117+
const int nvec
120118
) {
121119
// Note: numpy's py::array_t is row-major, but
122120
// our raw pointer-array is column-major
123-
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});
121+
py::array_t<std::complex<double>, py::array::f_style> psi({ld_psi, nvec});
124122
py::buffer_info psi_buf = psi.request();
125123
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
126-
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr);
124+
std::copy(psi_in, psi_in + nvec * ld_psi, psi_ptr);
127125

128126
py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi);
129127

130128
py::buffer_info hpsi_buf = hpsi.request();
131129
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
132-
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out);
130+
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
133131
};
134132

135133
obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(

python/pyabacus/src/py_diago_david.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,23 +111,21 @@ class PyDiagoDavid
111111
auto hpsi_func = [mm_op] (
112112
std::complex<double> *psi_in,
113113
std::complex<double> *hpsi_out,
114-
const int nband_in,
115-
const int nbasis_in,
116-
const int band_index1,
117-
const int band_index2
114+
const int ld_psi,
115+
const int nvec
118116
) {
119117
// Note: numpy's py::array_t is row-major, but
120118
// our raw pointer-array is column-major
121-
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});
119+
py::array_t<std::complex<double>, py::array::f_style> psi({ld_psi, nvec});
122120
py::buffer_info psi_buf = psi.request();
123121
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
124-
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr);
122+
std::copy(psi_in, psi_in + nvec * ld_psi, psi_ptr);
125123

126124
py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi);
127125

128126
py::buffer_info hpsi_buf = hpsi.request();
129127
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
130-
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out);
128+
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
131129
};
132130

133131
auto spsi_func = [this] (

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
124124

125125
// compute h*psi_in_iter
126126
// NOTE: bands after the first n_band should yield zero
127-
hpsi_func(this->psi_in_iter, this->hphi, this->nbase_x, this->dim, 0, this->nbase_x - 1);
127+
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
128+
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->nbase_x);
128129

129130
// at this stage, notconv = n_band and nbase = 0
130131
// note that nbase of cal_elem is an inout parameter: nbase := nbase + notconv
@@ -421,7 +422,8 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
421422
}
422423

423424
// update hpsi[:, nbase:nbase+notconv]
424-
hpsi_func(psi_iter, &hphi[nbase * this->dim], this->nbase_x, this->dim, nbase, nbase + notconv - 1);
425+
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
426+
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
425427

426428
ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
427429
return;
@@ -886,7 +888,8 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out
886888

887889
{
888890
// do hPsi for all bands
889-
hpsi_func(psi_pointer, hphi, n_band, dmax, 0, nstart - 1);
891+
// hphi[:, 0:nstart] = H * psi_pointer[:, 0:nstart]
892+
hpsi_func(psi_pointer, hphi, dmax, nstart);
890893

891894
gemm_op<T, Device>()(ctx,
892895
'C',

source/module_hsolver/diago_dav_subspace.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class Diago_DavSubspace : public DiagH<T, Device>
3131

3232
virtual ~Diago_DavSubspace() override;
3333

34-
using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
34+
// See diago_david.h for information on the HPsiFunc function type
35+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
3536

3637
int diag(const HPsiFunc& hpsi_func,
3738
T* psi_in,

source/module_hsolver/diago_david.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
152152
const SPsiFunc& spsi_func,
153153
const int dim,
154154
const int nband,
155-
const int ldPsi,
155+
const int ld_psi,
156156
T *psi_in,
157157
Real* eigenvalue_in,
158158
const Real david_diag_thr,
@@ -191,20 +191,20 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
191191
if(this->use_paw)
192192
{
193193
#ifdef USE_PAW
194-
GlobalC::paw_cell.paw_nl_psi(1, reinterpret_cast<const std::complex<double>*> (psi_in + m*ldPsi),
194+
GlobalC::paw_cell.paw_nl_psi(1, reinterpret_cast<const std::complex<double>*> (psi_in + m*ld_psi),
195195
reinterpret_cast<std::complex<double>*>(&this->spsi[m * dim]));
196196
#endif
197197
}
198198
else
199199
{
200-
// phm_in->sPsi(psi_in + m*ldPsi, &this->spsi[m * dim], dim, dim, 1);
201-
spsi_func(psi_in + m*ldPsi,&this->spsi[m*dim],dim,dim,1);
200+
// phm_in->sPsi(psi_in + m*ld_psi, &this->spsi[m * dim], dim, dim, 1);
201+
spsi_func(psi_in + m*ld_psi,&this->spsi[m*dim],dim,dim,1);
202202
}
203203
}
204204
// begin SchmidtOrth
205205
for (int m = 0; m < nband; m++)
206206
{
207-
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m, psi_in + m*ldPsi, dim);
207+
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m, psi_in + m*ld_psi, dim);
208208

209209
this->SchmidtOrth(dim,
210210
nband,
@@ -230,7 +230,9 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
230230
// end of SchmidtOrth and calculate H|psi>
231231
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
232232
// phm_in->ops->hPsi(dav_hpsi_in);
233-
hpsi_func(basis, hpsi, nbase_x, dim, 0, nband - 1);
233+
// hpsi[:, 0:nband] = H basis[:, 0:nband]
234+
// slice index in this piece of code is in C manner. i.e. 0:id stands for [0,id)
235+
hpsi_func(basis, hpsi, dim, nband);
234236

235237
this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);
236238

@@ -287,7 +289,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
287289

288290
// update eigenvectors of Hamiltonian
289291

290-
setmem_complex_op()(this->ctx, psi_in, 0, nband * ldPsi);
292+
setmem_complex_op()(this->ctx, psi_in, 0, nband * ld_psi);
291293
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
292294
gemm_op<T, Device>()(this->ctx,
293295
'N',
@@ -302,7 +304,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
302304
nbase_x,
303305
this->zero,
304306
psi_in, // C dim * nband
305-
ldPsi
307+
ld_psi
306308
);
307309

308310
if (!this->notconv || (dav_iter == david_maxiter))
@@ -322,7 +324,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
322324
nbase_x,
323325
eigenvalue_in,
324326
psi_in,
325-
ldPsi,
327+
ld_psi,
326328
this->hpsi,
327329
this->spsi,
328330
this->hcc,
@@ -601,7 +603,8 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
601603
// psi::Range(true, 0, nbase, nbase + notconv - 1),
602604
// &hpsi[nbase * dim]); // &hp(nbase, 0)
603605
// phm_in->ops->hPsi(dav_hpsi_in);
604-
hpsi_func(basis, &hpsi[nbase * dim], nbase_x, dim, nbase, nbase + notconv - 1);
606+
// hpsi[:, nbase:nbase+notcnv] = H basis[:, nbase:nbase+notcnv]
607+
hpsi_func(basis + nbase * dim, hpsi + nbase * dim, dim, notconv);
605608

606609
delmem_complex_op()(this->ctx, lagrange);
607610
delmem_complex_op()(this->ctx, vc_ev_vector);
@@ -785,7 +788,7 @@ void DiagoDavid<T, Device>::diag_zhegvx(const int& nbase,
785788
* @param nbase_x The maximum dimension of the reduced basis set.
786789
* @param eigenvalue_in Pointer to the array of eigenvalues.
787790
* @param psi_in Pointer to the array of wavefunctions.
788-
* @param ldPsi The leading dimension of the wavefunction array.
791+
* @param ld_psi The leading dimension of the wavefunction array.
789792
* @param hpsi Pointer to the output array for the updated basis set.
790793
* @param spsi Pointer to the output array for the updated basis set (nband-th column).
791794
* @param hcc Pointer to the output array for the updated reduced Hamiltonian.
@@ -800,7 +803,7 @@ void DiagoDavid<T, Device>::refresh(const int& dim,
800803
const int nbase_x, // maximum dimension of the reduced basis set
801804
const Real* eigenvalue_in,
802805
const T *psi_in,
803-
const int ldPsi,
806+
const int ld_psi,
804807
T* hpsi,
805808
T* spsi,
806809
T* hcc,
@@ -866,7 +869,7 @@ void DiagoDavid<T, Device>::refresh(const int& dim,
866869

867870
for (int m = 0; m < nband; m++)
868871
{
869-
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m,psi_in + m*ldPsi, dim);
872+
syncmem_complex_op()(this->ctx, this->ctx, basis + dim*m,psi_in + m*ld_psi, dim);
870873
/*for (int ig = 0; ig < npw; ig++)
871874
basis(m, ig) = psi(m, ig);*/
872875
}
@@ -1149,15 +1152,13 @@ void DiagoDavid<T, Device>::planSchmidtOrth(const int nband, std::vector<int>& p
11491152
/**
11501153
* @brief Performs iterative diagonalization using the David algorithm.
11511154
*
1152-
* @warning Please see docs of `HPsiFunc` for more information.
1153-
* @warning Please adhere strictly to the requirements of the function pointer
1154-
* @warning for the hpsi mat-vec interface; it may seem counterintuitive.
1155+
* @warning Please see docs of `HPsiFunc` for more information about the hpsi mat-vec interface.
11551156
*
11561157
* @tparam T The type of the elements in the matrix.
11571158
* @tparam Device The device type (CPU or GPU).
11581159
* @param hpsi_func The function object that computes the matrix-blockvector product H * psi.
11591160
* @param spsi_func The function object that computes the matrix-blockvector product overlap S * psi.
1160-
* @param ldPsi The leading dimension of the psi_in array.
1161+
* @param ld_psi The leading dimension of the psi_in array.
11611162
* @param psi_in The input wavefunction.
11621163
* @param eigenvalue_in The array to store the eigenvalues.
11631164
* @param david_diag_thr The convergence threshold for the diagonalization.
@@ -1172,7 +1173,7 @@ void DiagoDavid<T, Device>::planSchmidtOrth(const int nband, std::vector<int>& p
11721173
template <typename T, typename Device>
11731174
int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
11741175
const SPsiFunc& spsi_func,
1175-
const int ldPsi,
1176+
const int ld_psi,
11761177
T *psi_in,
11771178
Real* eigenvalue_in,
11781179
const Real david_diag_thr,
@@ -1187,7 +1188,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
11871188
int sum_dav_iter = 0;
11881189
do
11891190
{
1190-
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ldPsi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
1191+
sum_dav_iter += this->diag_once(hpsi_func, spsi_func, dim, nband, ld_psi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
11911192
++ntry;
11921193
} while (!check_block_conv(ntry, this->notconv, ntry_max, notconv_max));
11931194

source/module_hsolver/diago_david.h

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,50 +38,48 @@ class DiagoDavid : public DiagH<T, Device>
3838
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
3939
*
4040
* Called as follows:
41-
* hpsi(X, HX, nvec, dim, id_start, id_end)
42-
* Result is stored in HX.
43-
* HX = H * X[id_start:id_end]
41+
* hpsi(X, HX, ld, nvec) where X and HX are (ld, nvec)-shaped blockvectors.
42+
* Result HX = H * X is stored in HX.
4443
*
4544
* @param[out] X Head address of input blockvector of type `T*`.
46-
* @param[in] HX Where to write output blockvector of type `T*`.
47-
* @param[in] nvec Number of eigebpairs, i.e. number of vectors in a block.
48-
* @param[in] dim Dimension of matrix.
49-
* @param[in] id_start Start index of blockvector.
50-
* @param[in] id_end End index of blockvector.
45+
* @param[in] HX Head address of output blockvector of type `T*`.
46+
* @param[in] ld Leading dimension of blockvector.
47+
* @param[in] nvec Number of vectors in a block.
5148
*
52-
* @warning HX is the exact address to store output H*X[id_start:id_end];
53-
* @warning while X is the head address of input blockvector, \b without offset.
54-
* @warning Calling function should pass X and HX[offset] as arguments,
55-
* @warning where offset is usually id_start * leading dimension.
49+
* @warning X and HX are the exact address to read input X and store output H*X,
50+
* @warning both of size ld * nvec.
5651
*/
57-
using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
52+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
5853

5954
/**
6055
* @brief A function type representing the SX function.
56+
*
57+
* nrow is leading dimension of spsi, npw is leading dimension of psi, nbands is number of vecs
6158
*
6259
* This function type is used to define a matrix-blockvector operator S.
6360
* For generalized eigenvalue problem HX = λSX,
6461
* this function computes the product of the overlap matrix S and a blockvector X.
6562
*
66-
* @param[in] X Pointer to the input array.
67-
* @param[out] SX Pointer to the output array.
68-
* @param[in] nrow Dimension of SX: nbands * nrow.
69-
* @param[in] npw Number of plane waves.
70-
* @param[in] nbands Number of bands.
63+
* @param[in] X Pointer to the input blockvector.
64+
* @param[out] SX Pointer to the output blockvector.
65+
* @param[in] ld_spsi Leading dimension of spsi. Dimension of SX: nbands * nrow.
66+
* @param[in] ld_psi Leading dimension of psi. Number of plane waves.
67+
* @param[in] nbands Number of vectors.
7168
*
72-
* @note called as spsi(in, out, dim, dim, 1)
69+
* @note called like spsi(in, out, dim, dim, 1)
7370
*/
7471
using SPsiFunc = std::function<void(T*, T*, const int, const int, const int)>;
7572

76-
int diag(const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int, const int, const int)
77-
const SPsiFunc& spsi_func, // function void spsi(T*, T*, const int, const int, const int)
78-
const int ldPsi, // Leading dimension of the psi input
79-
T *psi_in, // Pointer to eigenvectors
80-
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
81-
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
82-
const int david_maxiter, // Maximum allowed iterations for the Davidson method
83-
const int ntry_max = 5, // Maximum number of diagonalization attempts (default is 5)
84-
const int notconv_max = 0); // Maximum number of allowed non-converged eigenvectors
73+
int diag(
74+
const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int)
75+
const SPsiFunc& spsi_func, // function void spsi(T*, T*, const int, const int, const int)
76+
const int ld_psi, // Leading dimension of the psi input
77+
T *psi_in, // Pointer to eigenvectors
78+
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
79+
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
80+
const int david_maxiter, // Maximum allowed iterations for the Davidson method
81+
const int ntry_max = 5, // Maximum number of diagonalization attempts (5 by default)
82+
const int notconv_max = 0); // Maximum number of allowed non-converged eigenvectors
8583

8684
private:
8785
bool use_paw = false;
@@ -130,7 +128,7 @@ class DiagoDavid : public DiagH<T, Device>
130128
const SPsiFunc& spsi_func,
131129
const int dim,
132130
const int nband,
133-
const int ldPsi,
131+
const int ld_psi,
134132
T *psi_in,
135133
Real* eigenvalue_in,
136134
const Real david_diag_thr,
@@ -163,20 +161,20 @@ class DiagoDavid : public DiagH<T, Device>
163161
const int nbase_x,
164162
const Real* eigenvalue,
165163
const T *psi_in,
166-
const int ldPsi,
164+
const int ld_psi,
167165
T* hpsi,
168166
T* spsi,
169167
T* hcc,
170168
T* scc,
171169
T* vcc);
172170

173171
void SchmidtOrth(const int& dim,
174-
const int nband,
175-
const int m,
176-
const T* spsi,
177-
T* lagrange_m,
178-
const int mm_size,
179-
const int mv_size);
172+
const int nband,
173+
const int m,
174+
const T* spsi,
175+
T* lagrange_m,
176+
const int mm_size,
177+
const int mv_size);
180178

181179
void planSchmidtOrth(const int nband, std::vector<int>& pre_matrix_mm_m, std::vector<int>& pre_matrix_mv_m);
182180

0 commit comments

Comments
 (0)