Skip to content

Commit 0b05ace

Browse files
committed
Perf: Davidson method and pw_gatherscatter.h
1 parent 8c0f5a3 commit 0b05ace

File tree

2 files changed

+74
-21
lines changed

2 files changed

+74
-21
lines changed

source/module_hsolver/diago_david.cpp

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,24 +68,27 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
6868
for (int m = 0; m < nband; m++)
6969
{
7070
// psi_m = psi(m)
71-
for (int ig = 0; ig < dim; ig++)
71+
ModuleBase::GlobalFunc::COPYARRAY(&psi(m, 0), psi_m.data(), dim);
72+
/*for (int ig = 0; ig < dim; ig++)
7273
{
7374
psi_m[ig] = psi(m, ig);
74-
}
75+
}*/
7576

7677
phm_in->sPsi(psi_m.data(), spsi.data(), (size_t)dim);
7778
this->SchmitOrth(dim, nband, m, basis, psi_m.data(), spsi.data());
7879
phm_in->sPsi(psi_m.data(), spsi.data(), (size_t)dim);
7980

8081
// basis(m) = psi_m, hp(m) = H |psi_m>, sp(m) = S |psi_m>
81-
std::complex<double>* sp_p = &sp(m, 0);
82+
ModuleBase::GlobalFunc::COPYARRAY(psi_m.data(), &basis(m, 0), dim);
83+
ModuleBase::GlobalFunc::COPYARRAY(spsi.data(), &sp(m, 0), dim);
84+
/*std::complex<double>* sp_p = &sp(m, 0);
8285
std::complex<double>* basis_p = &basis(m, 0);
8386
for (int ig = 0; ig < dim; ig++)
8487
{
8588
basis_p[ig] = psi_m[ig];
8689
//hp(m, ig) = hpsi[ig];
8790
sp_p[ig] = spsi[ig];
88-
}
91+
}*/
8992
}
9093
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, 0, nband-1));
9194
auto hp_psi = std::get<0>(phm_in->ops->hPsi(dav_hpsi_in));
@@ -150,7 +153,22 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
150153

151154
// update eigenvectors of Hamiltonian
152155
ModuleBase::GlobalFunc::ZEROS(psi.get_pointer(), psi.get_nbands() * psi.get_nbasis());
153-
for (int m = 0; m < nband; m++)
156+
char transa = 'N';
157+
char transb = 'T';
158+
zgemm_(&transa,
159+
&transb,
160+
&dim, // m: row of A,C
161+
&nband, // n: col of B,C
162+
&nbase, // k: col of A, row of B
163+
&ModuleBase::ONE, // alpha
164+
basis.get_pointer(), // A
165+
&basis.get_nbasis(), // LDA: if(N) max(1,m) if(T) max(1,k)
166+
vc.c, // B
167+
&nbase_x, // LDB: if(N) max(1,k) if(T) max(1,n)
168+
&ModuleBase::ZERO, // beta
169+
psi.get_pointer(), // C
170+
&psi.get_nbasis()); // LDC: if(N) max(1, m)
171+
/*for (int m = 0; m < nband; m++)
154172
{
155173
for (int j = 0; j < nbase; j++)
156174
{
@@ -159,7 +177,7 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
159177
psi(m, ig) += vc(j, m) * basis(j, ig);
160178
}
161179
}
162-
}
180+
}*/
163181

164182
if (!this->notconv || (dav_iter == DiagoIterAssist::PW_DIAG_NMAX))
165183
{
@@ -594,18 +612,31 @@ void DiagoDavid::SchmitOrth(const int &npw,
594612
std::complex<double> *lagrange = new std::complex<double>[m + 1];
595613
ModuleBase::GlobalFunc::ZEROS(lagrange, m + 1);
596614

597-
const int one = 1;
598-
for (int j = 0; j < m; j++)
615+
int inc = 1;
616+
int mp = m;
617+
char trans = 'C';
618+
zgemv_(&trans,
619+
&npw,
620+
&mp,
621+
&ModuleBase::ONE,
622+
psi.get_pointer(),
623+
&psi.get_nbasis(),
624+
spsi,
625+
&inc,
626+
&ModuleBase::ZERO,
627+
lagrange,
628+
&inc);
629+
/*for (int j = 0; j < m; j++)
599630
{
600631
const std::complex<double>* psi_p = &(psi(j, 0));
601632
zdotc_(&lagrange[j], &npw, psi_p, &one, spsi, &one);
602-
/*for (int ig = 0; ig < npw; ig++)
633+
for (int ig = 0; ig < npw; ig++)
603634
{
604635
lagrange[j] += conj(psi(j, ig)) * spsi[ig];
605-
}*/
636+
}
606637
// lagrange[j] = Diago_CG::ddot( npw, psi, j, spsi );
607-
}
608-
zdotc_(&lagrange[m], &npw, psi_m, &one, spsi, &one);
638+
}*/
639+
zdotc_(&lagrange[m], &npw, psi_m, &inc, spsi, &inc);
609640
/*for (int ig = 0; ig < npw; ig++)
610641
{
611642
lagrange[m] += conj(psi_m[ig]) * spsi[ig];
@@ -623,7 +654,7 @@ void DiagoDavid::SchmitOrth(const int &npw,
623654
for (int j = 0; j < m; j++)
624655
{
625656
const std::complex<double> alpha = -1 * lagrange[j];
626-
zaxpy_(&npw, &alpha, &psi(j,0), &one, psi_m, &one);
657+
zaxpy_(&npw, &alpha, &psi(j,0), &inc, psi_m, &inc);
627658
/*for (int ig = 0; ig < npw; ig++)
628659
{
629660
psi_m[ig] -= lagrange[j] * psi(j, ig);

source/module_pw/pw_gatherscatter.h

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,34 @@ void PW_Basis:: gatherp_scatters(std::complex<T> *in, std::complex<T> *out)
1616

1717
if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane,
1818
{
19+
std::complex<T> * outp, *inp;
1920
for(int is = 0 ; is < this->nst ; ++is)
2021
{
2122
int ixy = this->istot2ixy[is];
2223
//int ixy = (ixy / fftny)*ny + ixy % fftny;
24+
outp = &out[is*nz];
25+
inp = &in[ixy*nz];
2326
for(int iz = 0 ; iz < this->nz ; ++iz)
2427
{
25-
out[is*nz+iz] = in[ixy*nz+iz];
28+
outp[iz] = inp[iz];
2629
}
2730
}
2831
ModuleBase::timer::tick(this->classname, "gatherp_scatters");
2932
return;
3033
}
3134
#ifdef __MPI
3235
//change (nplane fftnxy) to (nplane,nstot)
33-
// Hence, we can send them at one time.
36+
// Hence, we can send them at one time.
37+
std::complex<T> * outp, *inp;
3438
for (int istot = 0;istot < nstot; ++istot)
3539
{
3640
int ixy = this->istot2ixy[istot];
3741
//int ixy = (ixy / fftny)*ny + ixy % fftny;
42+
outp = &out[istot*nplane];
43+
inp = &in[ixy*nplane];
3844
for (int iz = 0; iz < nplane; ++iz)
3945
{
40-
out[istot*nplane+iz] = in[ixy*nplane+iz];
46+
outp[iz] = inp[iz];
4147
}
4248
}
4349

@@ -48,14 +54,19 @@ void PW_Basis:: gatherp_scatters(std::complex<T> *in, std::complex<T> *out)
4854
else if(typeid(T) == typeid(float))
4955
MPI_Alltoallv(out, numr, startr, MPI_COMPLEX, in, numg, startg, MPI_COMPLEX, this->pool_world);
5056
// change (nz,ns) to (numz[ip],ns, poolnproc)
57+
std::complex<T> * outp0, *inp0;
5158
for (int ip = 0; ip < this->poolnproc ;++ip)
5259
{
5360
int nzip = this->numz[ip];
61+
outp0 = &out[startz[ip]];
62+
inp0 = &in[startg[ip]];
5463
for (int is = 0; is < this->nst; ++is)
5564
{
65+
outp = &outp0[is * nz];
66+
inp = &inp0[is * nzip ];
5667
for (int izip = 0; izip < nzip; ++izip)
5768
{
58-
out[ is * nz + startz[ip] + izip] = in[startg[ip] + is*nzip + izip];
69+
outp[izip] = inp[izip];
5970
}
6071
}
6172
}
@@ -78,29 +89,38 @@ void PW_Basis:: gathers_scatterp(std::complex<T> *in, std::complex<T> *out)
7889
if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot,
7990
{
8091
ModuleBase::GlobalFunc::ZEROS(out, this->nrxx);
92+
std::complex<T> * outp, *inp;
8193
for(int is = 0 ; is < this->nst ; ++is)
8294
{
8395
int ixy = istot2ixy[is];
8496
//int ixy = (ixy / fftny)*ny + ixy % fftny;
97+
outp = &out[ixy*nz];
98+
inp = &in[is*nz];
8599
for(int iz = 0 ; iz < this->nz ; ++iz)
86100
{
87-
out[ixy*nz+iz] = in[is*nz+iz];
101+
outp[iz] = inp[iz];
88102
}
89103
}
90104
ModuleBase::timer::tick(this->classname, "gathers_scatterp");
91105
return;
92106
}
93107
#ifdef __MPI
94108
// change (nz,ns) to (numz[ip],ns, poolnproc)
95-
// Hence, we can send them at one time.
109+
// Hence, we can send them at one time.
110+
std::complex<T> * outp, *inp;
111+
std::complex<T> * outp0, *inp0;
96112
for (int ip = 0; ip < this->poolnproc ;++ip)
97113
{
98114
int nzip = this->numz[ip];
115+
outp0 = &out[startg[ip]];
116+
inp0 = &in[startz[ip]];
99117
for (int is = 0; is < this->nst; ++is)
100118
{
119+
outp = &outp0[is * nzip];
120+
inp = &inp0[is * nz ];
101121
for (int izip = 0; izip < nzip; ++izip)
102122
{
103-
out[startg[ip] + is*nzip + izip] = in[ is * nz + startz[ip] + izip];
123+
outp[izip] = inp[izip];
104124
}
105125
}
106126
}
@@ -117,9 +137,11 @@ void PW_Basis:: gathers_scatterp(std::complex<T> *in, std::complex<T> *out)
117137
{
118138
int ixy = this->istot2ixy[istot];
119139
//int ixy = (ixy / fftny)*ny + ixy % fftny;
140+
outp = &out[ixy * nplane];
141+
inp = &in[istot * nplane];
120142
for (int iz = 0; iz < nplane; ++iz)
121143
{
122-
out[ixy*nplane+iz] = in[istot*nplane+iz];
144+
outp[iz] = inp[iz];
123145
}
124146
}
125147

0 commit comments

Comments
 (0)