Skip to content

Commit 126b36a

Browse files
authored
Merge pull request #805 from dyzheng/pre_cg_dav_UT
refactor: removed GlobalC and gloabl.h dependency in diago_cg.cpp and diago_david.cpp
2 parents 9ad8356 + 1813dc1 commit 126b36a

File tree

5 files changed

+33
-19
lines changed

5 files changed

+33
-19
lines changed

source/src_pw/diago_cg.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#include "diago_cg.h"
2-
#include "global.h"
32
#include "../src_parallel/parallel_reduce.h"
43
#include "../module_base/timer.h"
4+
#include "module_base/constants.h"
5+
#include "module_base/blas_connector.h"
56

67
int Diago_CG::moved = 0;
78

89

9-
Diago_CG::Diago_CG()
10+
Diago_CG::Diago_CG(Hamilt_PW* phamilt)
1011
{
12+
this->hpw = phamilt;
1113
test_cg=0;
1214
}
1315
Diago_CG::~Diago_CG() {}
@@ -66,10 +68,10 @@ void Diago_CG::diag
6668
if (test_cg>2) GlobalV::ofs_running << "Diagonal Band : " << m << std::endl;
6769
for (int i=0; i<dim; i++) phi_m[i] = phi(m, i);
6870

69-
GlobalC::hm.hpw.s_1psi(dim, phi_m, sphi); // sphi = S|psi(m)>
71+
this->hpw->s_1psi(dim, phi_m, sphi); // sphi = S|psi(m)>
7072
this->schmit_orth(dim, dmx, m, phi, sphi, phi_m);
7173

72-
GlobalC::hm.hpw.h_1psi(dim , phi_m, hphi, sphi);
74+
this->hpw->h_1psi(dim , phi_m, hphi, sphi);
7375

7476
e[m] = this->ddot_real(dim, phi_m, hphi );
7577

@@ -210,7 +212,7 @@ void Diago_CG::orthogonal_gradient( const int &dim, const int &dmx,
210212
if (test_cg==1) ModuleBase::TITLE("Diago_CG","orthogonal_gradient");
211213
//ModuleBase::timer::tick("Diago_CG","orth_grad");
212214

213-
GlobalC::hm.hpw.s_1psi(dim , g, sg);
215+
this->hpw->s_1psi(dim , g, sg);
214216
int inc=1;
215217
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
216218
//qianrui replace 2021-3-15
@@ -341,7 +343,7 @@ bool Diago_CG::update_psi(
341343
{
342344
if (test_cg==1) ModuleBase::TITLE("Diago_CG","update_psi");
343345
//ModuleBase::timer::tick("Diago_CG","update");
344-
GlobalC::hm.hpw.h_1psi(dim, cg, hcg, scg);
346+
this->hpw->h_1psi(dim, cg, hcg, scg);
345347
cg_norm = sqrt( this->ddot_real(dim, cg, scg) );
346348

347349
if (cg_norm < 1.0e-10 ) return 1;
@@ -476,7 +478,7 @@ void Diago_CG::schmit_orth
476478
{
477479
psi_m[ig] /= psi_norm;
478480
}
479-
GlobalC::hm.hpw.s_1psi(dim, psi_m, sphi); // sphi = S|psi(m)>
481+
this->hpw->s_1psi(dim, psi_m, sphi); // sphi = S|psi(m)>
480482

481483
delete [] lagrange ;
482484
//ModuleBase::timer::tick("Diago_CG","schmit_orth");

source/src_pw/diago_cg.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
#include "../module_base/global_variable.h"
66
#include "../module_base/complexmatrix.h"
77

8+
#include "src_pw/hamilt_pw.h"
9+
810
class Diago_CG
911
{
1012
public:
1113

12-
Diago_CG();
14+
Diago_CG(Hamilt_PW* phamilt);
1315
~Diago_CG();
1416

1517
static int moved;
@@ -33,7 +35,7 @@ class Diago_CG
3335
int &notconv,
3436
double &avg_iter);
3537

36-
static void schmit_orth(
38+
void schmit_orth(
3739
const int &dim,
3840
const int &dmx,
3941
const int &end,
@@ -44,6 +46,9 @@ class Diago_CG
4446

4547
private:
4648

49+
/// temp operator pointer
50+
Hamilt_PW* hpw;
51+
4752
int test_cg;
4853

4954
void calculate_gradient(

source/src_pw/diago_david.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#include "diago_david.h"
22
#include "diago_cg.h"
3-
#include "global.h"
43
#include "../src_parallel/parallel_reduce.h"
54
#include "../module_base/timer.h"
5+
#include "module_base/constants.h"
6+
#include "module_base/blas_connector.h"
7+
#include "module_base/lapack_connector.h"
68

7-
Diago_David::Diago_David()
9+
Diago_David::Diago_David(Hamilt_PW* phamilt)
810
{
11+
this->hpw = phamilt;
912
test_david = 2;
1013
// 1: check which function is called and which step is executed
1114
// 2: check the eigenvalues of the result of each iteration
@@ -86,7 +89,7 @@ void Diago_David::diag
8689

8790
this->SchmitOrth(npw, nband, m, basis, psi_m, spsi);
8891

89-
GlobalC::hm.hpw.h_1psi(npw, psi_m, hpsi, spsi);
92+
this->hpw->h_1psi(npw, psi_m, hpsi, spsi);
9093

9194
// basis(m) = psi_m, hp(m) = H |psi_m>, sp(m) = S |psi_m>
9295
for ( int ig = 0; ig < npw; ig++ )
@@ -259,7 +262,7 @@ void Diago_David::cal_grad
259262
*/
260263
this->SchmitOrth(npw, nbase+notconv, nbase+m, basis, ppsi, spsi);
261264

262-
GlobalC::hm.hpw.h_1psi(npw, ppsi, hpsi, spsi);
265+
this->hpw->h_1psi(npw, ppsi, hpsi, spsi);
263266

264267
for ( int ig = 0; ig < npw; ig++ )
265268
{
@@ -534,7 +537,7 @@ void Diago_David::SchmitOrth
534537
assert(m >= 0);
535538
assert(m < n_band);
536539

537-
GlobalC::hm.hpw.s_1psi(npw, psi_m, spsi);
540+
this->hpw->s_1psi(npw, psi_m, spsi);
538541

539542
std::complex<double>* lagrange = new std::complex<double>[m+1];
540543
ModuleBase::GlobalFunc::ZEROS( lagrange, m+1 );
@@ -589,7 +592,7 @@ void Diago_David::SchmitOrth
589592
}
590593
}
591594

592-
GlobalC::hm.hpw.s_1psi(npw, psi_m, spsi);
595+
this->hpw->s_1psi(npw, psi_m, spsi);
593596

594597
delete[] lagrange;
595598
ModuleBase::timer::tick("Diago_David","SchmitOrth");

source/src_pw/diago_david.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
#include "../module_base/global_variable.h"
1515
#include "../module_base/complexmatrix.h"
1616

17+
#include "src_pw/hamilt_pw.h"
18+
1719
class Diago_David
1820
{
1921
public:
2022

21-
Diago_David();
23+
Diago_David(Hamilt_PW* phamilt);
2224
~Diago_David();
2325

24-
static void SchmitOrth(
26+
void SchmitOrth(
2527
const int& npw,
2628
const int n_band,
2729
const int m,
@@ -52,6 +54,8 @@ class Diago_David
5254

5355
private:
5456

57+
Hamilt_PW* hpw;
58+
5559
int test_david;
5660

5761
void cal_grad(

source/src_pw/hamilt.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void Hamilt::diagH_pw(
8282

8383
avg_iter += 1.0;
8484
}
85-
Diago_CG cg;
85+
Diago_CG cg(&GlobalC::hm.hpw);
8686

8787
bool reorder = true;
8888

@@ -106,7 +106,7 @@ void Hamilt::diagH_pw(
106106
}
107107
else if(GlobalV::KS_SOLVER=="dav")
108108
{
109-
Diago_David david;
109+
Diago_David david(&GlobalC::hm.hpw);
110110
if(GlobalV::NPOL==1)
111111
{
112112
david.diag(GlobalC::wf.evc[ik0], GlobalC::wf.ekb[ik], GlobalC::kv.ngk[ik],

0 commit comments

Comments
 (0)