Skip to content

Commit 70d4c1f

Browse files
committed
Refactor hpsi_func of dav
1 parent a89c7ae commit 70d4c1f

File tree

3 files changed

+13
-21
lines changed

3 files changed

+13
-21
lines changed

source/module_hsolver/diago_david.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
230230
// end of SchmidtOrth and calculate H|psi>
231231
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
232232
// phm_in->ops->hPsi(dav_hpsi_in);
233-
hpsi_func(basis, hpsi, nbase_x, dim, 0, nband - 1);
233+
hpsi_func(basis, hpsi, dim, nband);
234234

235235
this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);
236236

@@ -601,7 +601,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
601601
// psi::Range(true, 0, nbase, nbase + notconv - 1),
602602
// &hpsi[nbase * dim]); // &hp(nbase, 0)
603603
// phm_in->ops->hPsi(dav_hpsi_in);
604-
hpsi_func(basis, &hpsi[nbase * dim], nbase_x, dim, nbase, nbase + notconv - 1);
604+
hpsi_func(basis + nbase * dim, hpsi + nbase * dim, dim, notconv);
605605

606606
delmem_complex_op()(this->ctx, lagrange);
607607
delmem_complex_op()(this->ctx, vc_ev_vector);
@@ -1149,9 +1149,7 @@ void DiagoDavid<T, Device>::planSchmidtOrth(const int nband, std::vector<int>& p
11491149
/**
11501150
* @brief Performs iterative diagonalization using the David algorithm.
11511151
*
1152-
* @warning Please see docs of `HPsiFunc` for more information.
1153-
* @warning Please adhere strictly to the requirements of the function pointer
1154-
* @warning for the hpsi mat-vec interface; it may seem counterintuitive.
1152+
* @warning Please see docs of `HPsiFunc` for more information about the hpsi mat-vec interface.
11551153
*
11561154
* @tparam T The type of the elements in the matrix.
11571155
* @tparam Device The device type (CPU or GPU).

source/module_hsolver/diago_david.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,13 @@ class DiagoDavid : public DiagH<T, Device>
4444
*
4545
* @param[out] X Head address of input blockvector of type `T*`.
4646
* @param[in] HX Where to write output blockvector of type `T*`.
47+
* @param[in] ld Leading dimension of matrix.
4748
* @param[in] nvec Number of eigebpairs, i.e. number of vectors in a block.
48-
* @param[in] dim Dimension of matrix.
49-
* @param[in] id_start Start index of blockvector.
50-
* @param[in] id_end End index of blockvector.
5149
*
52-
* @warning HX is the exact address to store output H*X[id_start:id_end];
53-
* @warning while X is the head address of input blockvector, \b without offset.
54-
* @warning Calling function should pass X and HX[offset] as arguments,
55-
* @warning where offset is usually id_start * leading dimension.
50+
* @warning X and HX are the exact address to read input X and store output H*X,
51+
* @warning both of size ld * nvec.
5652
*/
57-
using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
53+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
5854

5955
/**
6056
* @brief A function type representing the SX function.

source/module_hsolver/hsolver_pw.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -492,18 +492,16 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
492492
auto ngk_pointer = psi.get_ngk_pointer();
493493
/// wrap hpsi into lambda function, Matrix \times blockvector
494494
/// hpsi(X, HX, nband, dim, band_index1, band_index2)
495-
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
496-
T* hpsi_out,
497-
const int nband_in,
498-
const int nbasis_in,
499-
const int band_index1,
500-
const int band_index2) {
495+
auto hpsi_func = [hm, ngk_pointer](T *psi_in,
496+
T *hpsi_out,
497+
const int ldPsi,
498+
const int nvec) {
501499
ModuleBase::timer::tick("David", "hpsi_func");
502500

503501
// Convert "pointer data stucture" to a psi::Psi object
504-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nband_in, nbasis_in, ngk_pointer);
502+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ldPsi, ngk_pointer);
505503

506-
psi::Range bands_range(true, 0, band_index1, band_index2);
504+
psi::Range bands_range(true, 0, 0, nvec-1);
507505

508506
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
509507
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

0 commit comments

Comments
 (0)