Skip to content

Commit a902e66

Browse files
committed
Perf: accelerated Davidson diagonalization method in PW code
1 parent 70510ad commit a902e66

File tree

2 files changed

+92
-10
lines changed

2 files changed

+92
-10
lines changed

source/module_hamilt/ks_pw/operator_pw.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef __OPERATORPW
22
#define __OPERATORPW
33
#include"module_hamilt/operator.h"
4+
#include "module_base/timer.h"
45

56
namespace hamilt
67
{
@@ -13,9 +14,9 @@ class OperatorPW : public Operator
1314
//run this->act function for the first operator and run all act() for other nodes in chain table
1415
virtual hpsi_info hPsi(const hpsi_info& input)const
1516
{
17+
ModuleBase::timer::tick("OperatorPW", "hPsi");
1618
std::tuple<const std::complex<double>*, int> psi_info = std::get<0>(input)->to_range(std::get<1>(input));
1719
int n_npwx = std::get<1>(psi_info);
18-
const int npwx = std::get<0>(input)->get_nbasis();
1920

2021
std::complex<double> *tmhpsi = this->get_hpsi(input);
2122
const std::complex<double> *tmpsi_in = std::get<0>(psi_info);
@@ -27,8 +28,10 @@ class OperatorPW : public Operator
2728
node->act(std::get<0>(input), n_npwx, tmpsi_in, tmhpsi);
2829
node = (OperatorPW*)(node->next_op);
2930
}
31+
32+
ModuleBase::timer::tick("OperatorPW", "hPsi");
3033

31-
return hpsi_info(this->hpsi, std::get<1>(input));
34+
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/std::get<0>(input)->npol));
3235
}
3336

3437
virtual void act

source/module_hsolver/diago_david.cpp

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ void DiagoDavid::cal_elem(const int &npw,
266266
{
267267
if (test_david == 1)
268268
ModuleBase::TITLE("DiagoDavid", "cal_elem");
269+
270+
if(notconv == 0) return;
269271
ModuleBase::timer::tick("DiagoDavid", "cal_elem");
270272

271273
// update the reduced Hamiltonian
@@ -274,8 +276,43 @@ void DiagoDavid::cal_elem(const int &npw,
274276
// ModuleBase::GlobalFunc::ZEROS( hc.c+offset_h, notconv*hc.nr );
275277
// ModuleBase::GlobalFunc::ZEROS( sc.c+offset_s, notconv*sc.nr );
276278

277-
for (int i = nbase; i < nbase + notconv; i++)
279+
/*for (int i = nbase; i < nbase + notconv; i++)
278280
{
281+
char trans1 = 'C';
282+
char trans2 = 'N';
283+
const int nb_notc = i+1;
284+
const int one = 1;
285+
hc = transpose(hc, false);
286+
zgemm_(&trans1,
287+
&trans2,
288+
&one,
289+
&nb_notc,
290+
&npw,
291+
&ModuleBase::ONE,
292+
&basis(i , 0),
293+
&basis.get_nbasis(),
294+
hp.c,
295+
&hp.nc,
296+
&ModuleBase::ONE,
297+
&hc(0, i),
298+
&hc.nr);
299+
hc = transpose(hc, false);
300+
301+
sc = transpose(sc, false);
302+
zgemm_(&trans1,
303+
&trans2,
304+
&one,
305+
&nb_notc,
306+
&npw,
307+
&ModuleBase::ONE,
308+
&basis(i, 0),
309+
&basis.get_nbasis(),
310+
sp.c,
311+
&sp.nc,
312+
&ModuleBase::ONE,
313+
&sc(0, i),
314+
&sc.nr);
315+
sc = transpose(sc, false);
279316
for (int j = 0; j <= i; j++)
280317
{
281318
for (int ig = 0; ig < npw; ig++)
@@ -286,7 +323,43 @@ void DiagoDavid::cal_elem(const int &npw,
286323
// hc(j,i) = Diago_CG::ddot( npw, basis, j, hp, i );
287324
// sc(j,i) = Diago_CG::ddot( npw, basis, j, sp, i );
288325
}
289-
}
326+
std::cout<<__FILE__<<__LINE__<<" "<<i<<" "<<hc(i,0)<<" "<<hc(i,1)<<std::endl;
327+
}*/
328+
329+
char trans1 = 'C';
330+
char trans2 = 'N';
331+
const int nb_notc = (nbase + notconv);
332+
hc = transpose(hc, false);
333+
zgemm_(&trans1,
334+
&trans2,
335+
&notconv,
336+
&nb_notc,
337+
&npw,
338+
&ModuleBase::ONE,
339+
&basis(nbase, 0),
340+
&basis.get_nbasis(),
341+
hp.c,
342+
&hp.nc,
343+
&ModuleBase::ONE,
344+
hc.c + nbase,
345+
&hc.nr);
346+
hc = transpose(hc, false);
347+
348+
sc = transpose(sc, false);
349+
zgemm_(&trans1,
350+
&trans2,
351+
&notconv,
352+
&nb_notc,
353+
&npw,
354+
&ModuleBase::ONE,
355+
&basis(nbase, 0),
356+
&basis.get_nbasis(),
357+
sp.c,
358+
&sp.nc,
359+
&ModuleBase::ONE,
360+
sc.c + nbase,
361+
&sc.nr);
362+
sc = transpose(sc, false);
290363

291364
Parallel_Reduce::reduce_complex_double_pool(hc.c + offset_h, notconv * hc.nr);
292365
Parallel_Reduce::reduce_complex_double_pool(sc.c + offset_s, notconv * sc.nr);
@@ -521,18 +594,22 @@ void DiagoDavid::SchmitOrth(const int &npw,
521594
std::complex<double> *lagrange = new std::complex<double>[m + 1];
522595
ModuleBase::GlobalFunc::ZEROS(lagrange, m + 1);
523596

597+
const int one = 1;
524598
for (int j = 0; j < m; j++)
525599
{
526-
for (int ig = 0; ig < npw; ig++)
600+
const std::complex<double>* psi_p = &(psi(j, 0));
601+
zdotc_(&lagrange[j], &npw, psi_p, &one, spsi, &one);
602+
/*for (int ig = 0; ig < npw; ig++)
527603
{
528604
lagrange[j] += conj(psi(j, ig)) * spsi[ig];
529-
}
605+
}*/
530606
// lagrange[j] = Diago_CG::ddot( npw, psi, j, spsi );
531607
}
532-
for (int ig = 0; ig < npw; ig++)
608+
zdotc_(&lagrange[m], &npw, psi_m, &one, spsi, &one);
609+
/*for (int ig = 0; ig < npw; ig++)
533610
{
534611
lagrange[m] += conj(psi_m[ig]) * spsi[ig];
535-
}
612+
}*/
536613
// lagrange[m] = Diago_CG::ddot( npw, psi_m, spsi );
537614

538615
Parallel_Reduce::reduce_complex_double_pool(lagrange, m + 1);
@@ -545,10 +622,12 @@ void DiagoDavid::SchmitOrth(const int &npw,
545622

546623
for (int j = 0; j < m; j++)
547624
{
548-
for (int ig = 0; ig < npw; ig++)
625+
const std::complex<double> alpha = -1 * lagrange[j];
626+
zaxpy_(&npw, &alpha, &psi(j,0), &one, psi_m, &one);
627+
/*for (int ig = 0; ig < npw; ig++)
549628
{
550629
psi_m[ig] -= lagrange[j] * psi(j, ig);
551-
}
630+
}*/
552631
psi_norm -= (conj(lagrange[j]) * lagrange[j]).real();
553632
}
554633

0 commit comments

Comments
 (0)