Skip to content

Commit 9d4a7c9

Browse files
committed
Refactor hpsi_func of dav_subspace
1 parent 03742af commit 9d4a7c9

File tree

4 files changed

+16
-19
lines changed

4 files changed

+16
-19
lines changed

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
124124

125125
// compute h*psi_in_iter
126126
// NOTE: bands after the first n_band should yield zero
127-
hpsi_func(this->psi_in_iter, this->hphi, this->nbase_x, this->dim, 0, this->nbase_x - 1);
127+
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->nbase_x);
128128

129129
// at this stage, notconv = n_band and nbase = 0
130130
// note that nbase of cal_elem is an inout parameter: nbase := nbase + notconv
@@ -421,7 +421,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
421421
}
422422

423423
// update hpsi[:, nbase:nbase+notconv]
424-
hpsi_func(psi_iter, &hphi[nbase * this->dim], this->nbase_x, this->dim, nbase, nbase + notconv - 1);
424+
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
425425

426426
ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
427427
return;
@@ -886,7 +886,7 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out
886886

887887
{
888888
// do hPsi for all bands
889-
hpsi_func(psi_pointer, hphi, n_band, dmax, 0, nstart - 1);
889+
hpsi_func(psi_pointer, hphi, dmax, nstart);
890890

891891
gemm_op<T, Device>()(ctx,
892892
'C',

source/module_hsolver/diago_dav_subspace.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
3131

3232
virtual ~Diago_DavSubspace() override;
3333

34-
using HPsiFunc = std::function<void(T*, T*, const int, const int, const int, const int)>;
34+
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
3535

3636
int diag(const HPsiFunc& hpsi_func,
3737
T* psi_in,

source/module_hsolver/hsolver_pw.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -434,18 +434,17 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
434434
else if (this->method == "dav_subspace")
435435
{
436436
auto ngk_pointer = psi.get_ngk_pointer();
437-
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
438-
T* hpsi_out,
439-
const int nband_in,
440-
const int nbasis_in,
441-
const int band_index1,
442-
const int band_index2) {
437+
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
438+
auto hpsi_func = [hm, ngk_pointer](T *psi_in,
439+
T *hpsi_out,
440+
const int ldPsi,
441+
const int nvec) {
443442
ModuleBase::timer::tick("DavSubspace", "hpsi_func");
444443

445444
// Convert "pointer data stucture" to a psi::Psi object
446-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nband_in, nbasis_in, ngk_pointer);
445+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ldPsi, ngk_pointer);
447446

448-
psi::Range bands_range(true, 0, band_index1, band_index2);
447+
psi::Range bands_range(true, 0, 0, nvec-1);
449448

450449
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
451450
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);

source/module_lr/hsolver_lrtd.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,14 @@ namespace LR
118118
false, //always do the subspace diag (check the implementation)
119119
comm_info);
120120

121-
std::function<void(T*, T*, const int, const int, const int, const int)> hpsi_func = [pHamilt](
121+
auto hpsi_func = [pHamilt](
122122
T* psi_in,
123123
T* hpsi_out,
124-
const int nband_in,
125-
const int nbasis_in,
126-
const int band_index1,
127-
const int band_index2)
124+
const int ldPsi,
125+
const int nvec)
128126
{
129-
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nband_in, nbasis_in, nullptr);
130-
psi::Range bands_range(true, 0, band_index1, band_index2);
127+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ldPsi, nullptr);
128+
psi::Range bands_range(true, 0, 0, nvec-1);
131129
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
132130
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
133131
pHamilt->ops->hPsi(info);

0 commit comments

Comments
 (0)