Skip to content

Commit 1dc2336

Browse files
committed
Fix&Perf: fixed nonlocal_pw.cpp compile error, improved davidson method
1 parent 0b05ace commit 1dc2336

File tree

2 files changed

+142
-71
lines changed

2 files changed

+142
-71
lines changed

source/module_hamilt/ks_pw/nonlocal_pw.cpp

Lines changed: 65 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Nonlocal<OperatorPW>::Nonlocal
3030
}
3131
}
3232

33+
template<>
3334
void Nonlocal<OperatorPW>::init(const int ik_in)
3435
{
3536
this->ik = ik_in;
@@ -45,72 +46,10 @@ void Nonlocal<OperatorPW>::init(const int ik_in)
4546
}
4647
}
4748

48-
template<>
49-
void Nonlocal<OperatorPW>::act
50-
(
51-
const psi::Psi<std::complex<double>> *psi_in,
52-
const int n_npwx,
53-
const std::complex<double>* tmpsi_in,
54-
std::complex<double>* tmhpsi
55-
)const
56-
{
57-
ModuleBase::timer::tick("Operator", "NonlocalPW");
58-
this->npw = psi_in->get_ngk(this->ik);
59-
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
60-
this->npol = psi_in->npol;
61-
62-
if (this->ppcell->nkb > 0)
63-
{
64-
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
65-
// qianrui optimize 2021-3-31
66-
int nkb = this->ppcell->nkb;
67-
ModuleBase::ComplexMatrix becp(n_npwx, nkb, false);
68-
char transa = 'C';
69-
char transb = 'N';
70-
if (n_npwx == 1)
71-
{
72-
int inc = 1;
73-
zgemv_(&transa,
74-
&this->npw,
75-
&nkb,
76-
&ModuleBase::ONE,
77-
this->ppcell->vkb.c,
78-
&this->ppcell->vkb.nc,
79-
tmpsi_in,
80-
&inc,
81-
&ModuleBase::ZERO,
82-
becp.c,
83-
&inc);
84-
}
85-
else
86-
{
87-
int npm = n_npwx;
88-
zgemm_(&transa,
89-
&transb,
90-
&nkb,
91-
&npm,
92-
&this->npw,
93-
&ModuleBase::ONE,
94-
this->ppcell->vkb.c,
95-
&this->ppcell->vkb.nc,
96-
tmpsi_in,
97-
&this->max_npw,
98-
&ModuleBase::ZERO,
99-
becp.c,
100-
&nkb);
101-
}
102-
103-
Parallel_Reduce::reduce_complex_double_pool(becp.c, nkb * n_npwx);
104-
105-
this->add_nonlocal_pp(tmhpsi, becp.c, n_npwx);
106-
}
107-
ModuleBase::timer::tick("Operator", "NonlocalPW");
108-
return;
109-
}
110-
11149
//--------------------------------------------------------------------------
11250
// this function sum up each non-local pseudopotential located on each atom,
11351
//--------------------------------------------------------------------------
52+
template<>
11453
void Nonlocal<OperatorPW>::add_nonlocal_pp(std::complex<double> *hpsi_in, const std::complex<double> *becp, const int m) const
11554
{
11655
ModuleBase::timer::tick("Nonlocal", "add_nonlocal_pp");
@@ -231,4 +170,67 @@ void Nonlocal<OperatorPW>::add_nonlocal_pp(std::complex<double> *hpsi_in, const
231170
return;
232171
}
233172

173+
template<>
174+
void Nonlocal<OperatorPW>::act
175+
(
176+
const psi::Psi<std::complex<double>> *psi_in,
177+
const int n_npwx,
178+
const std::complex<double>* tmpsi_in,
179+
std::complex<double>* tmhpsi
180+
)const
181+
{
182+
ModuleBase::timer::tick("Operator", "NonlocalPW");
183+
this->npw = psi_in->get_ngk(this->ik);
184+
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
185+
this->npol = psi_in->npol;
186+
187+
if (this->ppcell->nkb > 0)
188+
{
189+
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
190+
// qianrui optimize 2021-3-31
191+
int nkb = this->ppcell->nkb;
192+
ModuleBase::ComplexMatrix becp(n_npwx, nkb, false);
193+
char transa = 'C';
194+
char transb = 'N';
195+
if (n_npwx == 1)
196+
{
197+
int inc = 1;
198+
zgemv_(&transa,
199+
&this->npw,
200+
&nkb,
201+
&ModuleBase::ONE,
202+
this->ppcell->vkb.c,
203+
&this->ppcell->vkb.nc,
204+
tmpsi_in,
205+
&inc,
206+
&ModuleBase::ZERO,
207+
becp.c,
208+
&inc);
209+
}
210+
else
211+
{
212+
int npm = n_npwx;
213+
zgemm_(&transa,
214+
&transb,
215+
&nkb,
216+
&npm,
217+
&this->npw,
218+
&ModuleBase::ONE,
219+
this->ppcell->vkb.c,
220+
&this->ppcell->vkb.nc,
221+
tmpsi_in,
222+
&this->max_npw,
223+
&ModuleBase::ZERO,
224+
becp.c,
225+
&nkb);
226+
}
227+
228+
Parallel_Reduce::reduce_complex_double_pool(becp.c, nkb * n_npwx);
229+
230+
this->add_nonlocal_pp(tmhpsi, becp.c, n_npwx);
231+
}
232+
ModuleBase::timer::tick("Operator", "NonlocalPW");
233+
return;
234+
}
235+
234236
} // namespace hamilt

source/module_hsolver/diago_david.cpp

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,44 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
231231
// expand the reduced basis set with the new basis vectors P|R(psi)>...
232232
// in which psi are the last eigenvectors
233233
// we define |R(psi)> as (H-ES)*|Psi>, E = <psi|H|psi>/<psi|S|psi>
234+
std::vector<std::complex<double>> vc_ev_vector(nbase);
234235
for (int m = 0; m < notconv; m++)
235236
{
236-
ModuleBase::GlobalFunc::ZEROS(respsi, npw);
237+
for(int i = 0; i < nbase; i++)
238+
{
239+
vc_ev_vector[i] = vc(i, unconv[m]);
240+
}
241+
int inc = 1;
242+
char trans = 'N';
243+
zgemv_(&trans,
244+
&npw,
245+
&nbase,
246+
&ModuleBase::ONE,
247+
hp.c,
248+
&hp.nc,
249+
vc_ev_vector.data(),
250+
&inc,
251+
&ModuleBase::ZERO,
252+
respsi,
253+
&inc);
254+
255+
for(int i = 0; i < nbase; i++)
256+
{
257+
vc_ev_vector[i] *= -1 * eigenvalue[unconv[m]];
258+
}
259+
zgemv_(&trans,
260+
&npw,
261+
&nbase,
262+
&ModuleBase::ONE,
263+
sp.c,
264+
&sp.nc,
265+
vc_ev_vector.data(),
266+
&inc,
267+
&ModuleBase::ONE,
268+
respsi,
269+
&inc);
270+
271+
/*ModuleBase::GlobalFunc::ZEROS(respsi, npw);
237272
for (int i = 0; i < nbase; i++)
238273
{
239274
hpsi = &(hp(i, 0));
@@ -244,7 +279,8 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
244279
{
245280
respsi[ig] += vc_value * (hpsi[ig] - ev_value * spsi[ig]);
246281
}
247-
}
282+
}*/
283+
248284

249285
ppsi = &basis(nbase + m, 0);
250286
spsi = &sp(nbase + m, 0);
@@ -505,7 +541,37 @@ void DiagoDavid::refresh(const int &npw,
505541

506542
// update hp,sp
507543
basis.zero_out();
508-
for (int m = 0; m < nband; m++)
544+
char transa = 'N';
545+
char transb = 'T';
546+
zgemm_(&transa,
547+
&transb,
548+
&npw, // m: row of A,C
549+
&nband, // n: col of B,C
550+
&nbase, // k: col of A, row of B
551+
&ModuleBase::ONE, // alpha
552+
hp.c, // A
553+
&hp.nc, // LDA: if(N) max(1,m) if(T) max(1,k)
554+
vc.c, // B
555+
&vc.nc, // LDB: if(N) max(1,k) if(T) max(1,n)
556+
&ModuleBase::ZERO, // beta
557+
basis.get_pointer(), // C
558+
&basis.get_nbasis()); // LDC: if(N) max(1, m)
559+
560+
zgemm_(&transa,
561+
&transb,
562+
&npw, // m: row of A,C
563+
&nband, // n: col of B,C
564+
&nbase, // k: col of A, row of B
565+
&ModuleBase::ONE, // alpha
566+
sp.c, // A
567+
&sp.nc, // LDA: if(N) max(1,m) if(T) max(1,k)
568+
vc.c, // B
569+
&vc.nc, // LDB: if(N) max(1,k) if(T) max(1,n)
570+
&ModuleBase::ZERO, // beta
571+
&basis(nband, 0), // C
572+
&basis.get_nbasis()); // LDC: if(N) max(1, m)
573+
574+
/*for (int m = 0; m < nband; m++)
509575
{
510576
for (int j = 0; j < nbase; j++)
511577
{
@@ -515,23 +581,26 @@ void DiagoDavid::refresh(const int &npw,
515581
basis(m + nband, ig) += vc(j, m) * sp(j, ig);
516582
}
517583
}
518-
}
584+
}*/
519585

520-
for (int m = 0; m < nband; m++)
586+
ModuleBase::GlobalFunc::COPYARRAY(&basis(0, 0), &hp(0, 0), npw * nband);
587+
ModuleBase::GlobalFunc::COPYARRAY(&basis(nband, 0), &sp(0, 0), npw * nband);
588+
/*for (int m = 0; m < nband; m++)
521589
{
522590
for (int ig = 0; ig < npw; ig++)
523591
{
524592
hp(m, ig) = basis(m, ig);
525593
sp(m, ig) = basis(m + nband, ig);
526594
}
527-
}
595+
}*/
528596

529597
// update basis
530598
basis.zero_out();
531599
for (int m = 0; m < nband; m++)
532600
{
533-
for (int ig = 0; ig < npw; ig++)
534-
basis(m, ig) = psi(m, ig);
601+
ModuleBase::GlobalFunc::COPYARRAY(&psi(m, 0), &basis(m, 0), npw);
602+
/*for (int ig = 0; ig < npw; ig++)
603+
basis(m, ig) = psi(m, ig);*/
535604
}
536605

537606
// update the reduced Hamiltonian

0 commit comments

Comments
 (0)