Skip to content

Commit 5329628

Browse files
Change wrapper spsi_func in hsolver-dav (#5205)
* Change wrapper spsi_func in hsolver-dav * Update spsi_func in hsolver_lrtd * Update david test spsi_func * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent e1958cb commit 5329628

File tree

9 files changed

+34
-32
lines changed

9 files changed

+34
-32
lines changed

python/pyabacus/src/py_diago_david.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ class PyDiagoDavid
132132
const std::complex<double> *psi_in,
133133
std::complex<double> *spsi_out,
134134
const int nrow,
135-
const int npw,
136135
const int nbands
137136
) {
138137
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));

source/module_hsolver/diago_david.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
198198
else
199199
{
200200
// phm_in->sPsi(psi_in + m*ld_psi, &this->spsi[m * dim], dim, dim, 1);
201-
spsi_func(psi_in + m*ld_psi,&this->spsi[m*dim],dim,dim,1);
201+
spsi_func(psi_in + m*ld_psi,&this->spsi[m*dim],dim, 1);
202202
}
203203
}
204204
// begin SchmidtOrth
@@ -223,7 +223,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
223223
else
224224
{
225225
// phm_in->sPsi(basis + dim*m, &this->spsi[m * dim], dim, dim, 1);
226-
spsi_func(basis + dim*m, &this->spsi[m * dim], dim, dim, 1);
226+
spsi_func(basis + dim*m, &this->spsi[m * dim], dim, 1);
227227
}
228228
}
229229

@@ -554,7 +554,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
554554
else
555555
{
556556
// phm_in->sPsi(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
557-
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
557+
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, 1);
558558
}
559559
}
560560
// first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
@@ -595,7 +595,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
595595
else
596596
{
597597
// phm_in->sPsi(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
598-
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, dim, 1);
598+
spsi_func(basis + dim*(nbase + m), &spsi[(nbase + m) * dim], dim, 1);
599599
}
600600
}
601601
// calculate H|psi> for not convergence bands

source/module_hsolver/diago_david.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,10 @@ class DiagoDavid : public DiagH<T, Device>
6262
*
6363
* @param[in] X Pointer to the input blockvector.
6464
* @param[out] SX Pointer to the output blockvector.
65-
* @param[in] ld_spsi Leading dimension of spsi. Dimension of SX: nbands * nrow.
66-
* @param[in] ld_psi Leading dimension of psi. Number of plane waves.
67-
* @param[in] nbands Number of vectors.
68-
*
69-
* @note called like spsi(in, out, dim, dim, 1)
65+
* @param[in] ld_psi Leading dimension of psi and spsi. Dimension of X&SX: ld * nvec.
66+
* @param[in] nvec Number of vectors.
7067
*/
71-
using SPsiFunc = std::function<void(T*, T*, const int, const int, const int)>;
68+
using SPsiFunc = std::function<void(T*, T*, const int, const int)>;
7269

7370
int diag(
7471
const HPsiFunc& hpsi_func, // function void hpsi(T*, T*, const int, const int)

source/module_hsolver/hsolver_pw.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,17 +511,16 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
511511
};
512512

513513
/// wrap spsi into lambda function, Matrix \times blockvector
514-
/// spsi(X, SX, nrow, npw, nbands)
515-
/// nrow is leading dimension of spsi, npw is leading dimension of psi, nbands is number of vecs
514+
/// spsi(X, SX, ld, nvec)
515+
/// ld is leading dimension of psi and spsi
516516
auto spsi_func = [hm](const T* psi_in, T* spsi_out,
517-
const int ld_spsi, // Leading dimension of spsi. Dimension of SX: nbands * nrow.
518-
const int ld_psi, // Leading dimension of psi. Number of plane waves.
517+
const int ld_psi, // Leading dimension of psi and spsi.
519518
const int nvec // Number of vectors(bands)
520519
){
521520
ModuleBase::timer::tick("David", "spsi_func");
522521
// sPsi determines S=I or not by PARAM.globalv.use_uspp inside
523522
// sPsi(psi, spsi, nrow, npw, nbands)
524-
hm->sPsi(psi_in, spsi_out, ld_spsi, ld_psi, nvec);
523+
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
525524
ModuleBase::timer::tick("David", "spsi_func");
526525
};
527526

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ void lapackEigen(int &npw, std::vector<std::complex<float>> &hm, float * e, bool
4646
char tmp_c1 = 'V', tmp_c2 = 'U';
4747
cheev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, work2, &lwork, rwork, &info);
4848
end = clock();
49-
if(info) std::cout << "ERROR: Lapack solver, info=" << info <<std::endl;
50-
if (outtime) std::cout<<"Lapack Run time: "<<(float)(end - start) / CLOCKS_PER_SEC<<" S"<<std::endl;
49+
if(info) { std::cout << "ERROR: Lapack solver, info=" << info <<std::endl;
50+
}
51+
if (outtime) { std::cout<<"Lapack Run time: "<<(float)(end - start) / CLOCKS_PER_SEC<<" S"<<std::endl;
52+
}
5153

5254
delete [] rwork;
5355
delete [] work2;
@@ -74,7 +76,8 @@ class DiagoDavPrepare
7476
//calculate eigenvalues by LAPACK;
7577
float* e_lapack = new float[npw];
7678
float* ev;
77-
if(mypnum == 0) lapackEigen(npw, DIAGOTEST::hmatrix_f, e_lapack,DETAILINFO);
79+
if(mypnum == 0) { lapackEigen(npw, DIAGOTEST::hmatrix_f, e_lapack,DETAILINFO);
80+
}
7881

7982
//do Diago_David::diag()
8083
float* en = new float[npw];
@@ -111,13 +114,13 @@ class DiagoDavPrepare
111114
const int ld_psi, const int nvec)
112115
{
113116
auto psi_iter_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, nullptr);
114-
psi::Range bands_range(1, 0, 0, nvec-1);
117+
psi::Range bands_range(true, 0, 0, nvec-1);
115118
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
116119
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
117120
phm->ops->hPsi(info);
118121
};
119-
auto spsi_func = [phm](const std::complex<float>* psi_in, std::complex<float>* spsi_out,const int nrow, const int npw, const int nbands){
120-
phm->sPsi(psi_in, spsi_out, nrow, npw, nbands);
122+
auto spsi_func = [phm](const std::complex<float>* psi_in, std::complex<float>* spsi_out,const int ld_psi, const int nbands){
123+
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
121124
};
122125
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
123126

@@ -131,7 +134,8 @@ class DiagoDavPrepare
131134

132135
if(mypnum == 0)
133136
{
134-
if (DETAILINFO) std::cout<<"diag Run time: "<< use_time << std::endl;
137+
if (DETAILINFO) { std::cout<<"diag Run time: "<< use_time << std::endl;
138+
}
135139
for(int i=0;i<nband;i++)
136140
{
137141
EXPECT_NEAR(en[i],e_lapack[i],CONVTHRESHOLD);
@@ -148,8 +152,9 @@ class DiagoDavTest : public ::testing::TestWithParam<DiagoDavPrepare> {};
148152
TEST_P(DiagoDavTest,RandomHamilt)
149153
{
150154
DiagoDavPrepare ddp = GetParam();
151-
if (DETAILINFO&&ddp.mypnum==0) std::cout << "npw=" << ddp.npw << ", nband=" << ddp.nband << ", sparsity="
155+
if (DETAILINFO&&ddp.mypnum==0) { std::cout << "npw=" << ddp.npw << ", nband=" << ddp.nband << ", sparsity="
152156
<< ddp.sparsity << ", eps=" << ddp.eps << std::endl;
157+
}
153158

154159
HPsi<std::complex<float>> hpsi(ddp.nband, ddp.npw, ddp.sparsity);
155160
DIAGOTEST::hmatrix_f = hpsi.hamilt();
@@ -236,7 +241,8 @@ int main(int argc, char **argv)
236241

237242
testing::InitGoogleTest(&argc, argv);
238243
::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners();
239-
if (myrank != 0) delete listeners.Release(listeners.default_result_printer());
244+
if (myrank != 0) { delete listeners.Release(listeners.default_result_printer());
245+
}
240246

241247
int result = RUN_ALL_TESTS();
242248
if (myrank == 0 && result != 0)

source/module_hsolver/test/diago_david_real_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ class DiagoDavPrepare
118118
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
119119
phm->ops->hPsi(info);
120120
};
121-
auto spsi_func = [phm](const double* psi_in, double* spsi_out,const int nrow, const int npw, const int nbands){
122-
phm->sPsi(psi_in, spsi_out, nrow, npw, nbands);
121+
auto spsi_func = [phm](const double* psi_in, double* spsi_out,const int ld_psi, const int nbands){
122+
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
123123
};
124124
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
125125

source/module_hsolver/test/diago_david_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ class DiagoDavPrepare
118118
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
119119
phm->ops->hPsi(info);
120120
};
121-
auto spsi_func = [phm](const std::complex<double>* psi_in, std::complex<double>* spsi_out,const int nrow, const int npw, const int nbands){
122-
phm->sPsi(psi_in, spsi_out, nrow, npw, nbands);
121+
auto spsi_func = [phm](const std::complex<double>* psi_in, std::complex<double>* spsi_out,const int ld_psi, const int nbands){
122+
phm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
123123
};
124124
dav.diag(hpsi_func,spsi_func, ld_psi, phi.get_pointer(), en, eps, maxiter);
125125

source/module_hsolver/test/hsolver_pw_sup.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ DiagoDavid<T, Device>::~DiagoDavid() {
154154

155155
template <typename T, typename Device>
156156
int DiagoDavid<T, Device>::diag(const std::function<void(T*, T*, const int, const int)>& hpsi_func,
157-
const std::function<void(T*, T*, const int, const int, const int)>& spsi_func,
157+
const std::function<void(T*, T*, const int, const int)>& spsi_func,
158158
const int ld_psi,
159159
T *psi_in,
160160
Real* eigenvalue_in,

source/module_lr/hsolver_lrtd.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ namespace LR
9595
pHamilt->ops->hPsi(info);
9696
};
9797
auto spsi_func = [pHamilt](const T* psi_in, T* spsi_out,
98-
const int nrow, const int npw, const int nbands){
98+
const int ld_psi, const int nbands)
99+
{
99100
// sPsi determines S=I or not by PARAM.globalv.use_uspp inside
100-
pHamilt->sPsi(psi_in, spsi_out, nrow, npw, nbands);
101+
pHamilt->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nbands);
101102
};
102103

103104
const int& dim = psi_k1_dav.get_nbasis(); //equals to leading dimension here

0 commit comments

Comments
 (0)