Skip to content

Commit 6646e66

Browse files
committed
Update docs for new hpsi_func
1 parent 65a5cd0 commit 6646e66

File tree

5 files changed

+14
-8
lines changed

5 files changed

+14
-8
lines changed

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
124124

125125
// compute h*psi_in_iter
126126
// NOTE: bands after the first n_band should yield zero
127+
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
127128
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->nbase_x);
128129

129130
// at this stage, notconv = n_band and nbase = 0
@@ -421,6 +422,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
421422
}
422423

423424
// update hpsi[:, nbase:nbase+notconv]
425+
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
424426
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
425427

426428
ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
@@ -886,6 +888,7 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out
886888

887889
{
888890
// do hPsi for all bands
891+
// hphi[:, 0:nstart] = H * psi_pointer[:, 0:nstart]
889892
hpsi_func(psi_pointer, hphi, dmax, nstart);
890893

891894
gemm_op<T, Device>()(ctx,

source/module_hsolver/diago_dav_subspace.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
3131

3232
virtual ~Diago_DavSubspace() override;
3333

34+
// See diago_david.h for information on the HPsiFunc function type
3435
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
3536

3637
int diag(const HPsiFunc& hpsi_func,

source/module_hsolver/diago_david.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
230230
// end of SchmidtOrth and calculate H|psi>
231231
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
232232
// phm_in->ops->hPsi(dav_hpsi_in);
233+
// hpsi[:, 0:nband] = H basis[:, 0:nband]
234+
// slice index in this piece of code is in C manner. i.e. 0:id stands for [0,id)
233235
hpsi_func(basis, hpsi, dim, nband);
234236

235237
this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);
@@ -601,6 +603,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
601603
// psi::Range(true, 0, nbase, nbase + notconv - 1),
602604
// &hpsi[nbase * dim]); // &hp(nbase, 0)
603605
// phm_in->ops->hPsi(dav_hpsi_in);
606+
// hpsi[:, nbase:nbase+notcnv] = H basis[:, nbase:nbase+notcnv]
604607
hpsi_func(basis + nbase * dim, hpsi + nbase * dim, dim, notconv);
605608

606609
delmem_complex_op()(this->ctx, lagrange);

source/module_hsolver/diago_david.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,13 @@ class DiagoDavid : public DiagH<T, Device>
3838
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
3939
*
4040
* Called as follows:
41-
* hpsi(X, HX, nvec, dim, id_start, id_end)
42-
* Result is stored in HX.
43-
* HX = H * X[id_start:id_end]
41+
* hpsi(X, HX, ld, nvec) where X and HX are (ld, nvec)-shaped blockvectors.
42+
* Result HX = H * X is stored in HX.
4443
*
4544
* @param[out] X Head address of input blockvector of type `T*`.
46-
* @param[in] HX Where to write output blockvector of type `T*`.
47-
* @param[in] ld Leading dimension of matrix.
48-
* @param[in] nvec Number of eigebpairs, i.e. number of vectors in a block.
45+
* @param[in] HX Head address of output blockvector of type `T*`.
46+
* @param[in] ld Leading dimension of matrix.
47+
* @param[in] nvec Number of eigenpairs, i.e. number of vectors in a block.
4948
*
5049
* @warning X and HX are the exact address to read input X and store output H*X,
5150
* @warning both of size ld * nvec.

source/module_hsolver/hsolver_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,14 +491,14 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
491491

492492
auto ngk_pointer = psi.get_ngk_pointer();
493493
/// wrap hpsi into lambda function, Matrix \times blockvector
494-
/// hpsi(X, HX, nband, dim, band_index1, band_index2)
494+
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
495495
auto hpsi_func = [hm, ngk_pointer](T *psi_in,
496496
T *hpsi_out,
497497
const int ldPsi,
498498
const int nvec) {
499499
ModuleBase::timer::tick("David", "hpsi_func");
500500

501-
// Convert "pointer data stucture" to a psi::Psi object
501+
// Convert pointer of psi_in to a psi::Psi object
502502
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ldPsi, ngk_pointer);
503503

504504
psi::Range bands_range(true, 0, 0, nvec-1);

0 commit comments

Comments
 (0)