Skip to content

Commit ffdebd8

Browse files
committed
Change wrapper spsi_func in hsolver-dav
1 parent 72b1d7c commit ffdebd8

File tree

4 files changed

+11
-16
lines changed

4 files changed

+11
-16
lines changed

python/pyabacus/src/py_diago_david.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ class PyDiagoDavid
132132
const std::complex<double> *psi_in,
133133
std::complex<double> *spsi_out,
134134
const int nrow,
135-
const int npw,
136135
const int nbands
137136
) {
138137
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));

source/module_hsolver/diago_david.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
198198
else
199199
{
200200
// phm_in->sPsi(psi_in + m*ld_psi, &this->spsi[m * dim], dim, dim, 1);
201-
spsi_func(psi_in + m*ld_psi,&this->spsi[m*dim],dim,dim,1);
201+
spsi_func(psi_in + m*ld_psi,&this->spsi[m*dim],dim, 1);
202202
}
203203
}
204204
// begin SchmidtOrth
@@ -223,7 +223,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
223223
else
224224
{
225225
// phm_in->sPsi(basis + dim*m, &this->spsi[m * dim], dim, dim, 1);
226-
spsi_func(basis + dim*m, &this->spsi[m * dim], dim, dim, 1);
226+
spsi_func(basis + dim*m, &this->spsi[m * dim], dim, 1);
227227
}
228228
}
229229

@@ -554,7 +554,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
554554
else
555555
{
556556
// phm_in->sPsi(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
557-
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
557+
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, 1);
558558
}
559559
}
560560
// first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
@@ -595,7 +595,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
595595
else
596596
{
597597
// phm_in->sPsi(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
598-
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
598+
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, 1);
599599
}
600600
}
601601
// calculate H|psi> for not convergence bands

source/module_hsolver/diago_david.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,10 @@ class DiagoDavid : public DiagH<T, Device>
6262
*
6363
* @param[in] X Pointer to the input blockvector.
6464
* @param[out] SX Pointer to the output blockvector.
65-
* @param[in] ld_spsi Leading dimension of spsi. Dimension of SX: nbands * nrow.
66-
* @param[in] ld_psi Leading dimension of psi. Number of plane waves.
67-
* @param[in] nbands Number of vectors.
68-
*
69-
* @note called like spsi(in, out, dim, dim, 1)
65+
* @param[in] ld_psi Leading dimension of psi and spsi. Dimension of X&SX: ld * nvec.
66+
* @param[in] nvec Number of vectors.
7067
*/
71-
using SPsiFunc = std::function<void(T*, T*, const int, const int, const int)>;
68+
using SPsiFunc = std::function<void(T*, T*, const int, const int)>;
7269

7370
int diag(
7471
const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int)

source/module_hsolver/hsolver_pw.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,17 +511,16 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
511511
};
512512

513513
/// wrap spsi into lambda function, Matrix \times blockvector
514-
/// spsi(X, SX, nrow, npw, nbands)
515-
/// nrow is leading dimension of spsi, npw is leading dimension of psi, nbands is number of vecs
514+
/// spsi(X, SX, ld, nvec)
515+
/// ld is leading dimension of psi and spsi
516516
auto spsi_func = [hm](const T* psi_in, T* spsi_out,
517-
const int ld_spsi, // Leading dimension of spsi. Dimension of SX: nbands * nrow.
518-
const int ld_psi, // Leading dimension of psi. Number of plane waves.
517+
const int ld_psi, // Leading dimension of psi and spsi.
519518
const int nvec // Number of vectors(bands)
520519
){
521520
ModuleBase::timer::tick("David", "spsi_func");
522521
// sPsi determines S=I or not by GlobalV::use_uspp inside
523522
// sPsi(psi, spsi, nrow, npw, nbands)
524-
hm->sPsi(psi_in, spsi_out, ld_spsi, ld_psi, nvec);
523+
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
525524
ModuleBase::timer::tick("David", "spsi_func");
526525
};
527526

0 commit comments

Comments
 (0)