Skip to content

Commit 6558216

Browse files
authored
Merge pull request #811 from dyzheng/pre_cg_dav_UT
fix : compiling conflict for USE_CUDA macro
2 parents 0c8a108 + 8c1bc7c commit 6558216

File tree

12 files changed

+94
-73
lines changed

12 files changed

+94
-73
lines changed

source/module_base/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@ add_library(
99
inverse_matrix.cpp
1010
global_file.cpp
1111
global_function.cpp
12+
global_function_ddotreal.cpp
1213
global_variable.cpp
1314
intarray.cpp
1415
math_integral.cpp
1516
math_polyint.cpp
1617
math_sphbes.cpp
1718
math_ylmreal.cpp
18-
math_bspline.cpp
19+
math_bspline.cpp
1920
mathzone.cpp
2021
mathzone_add1.cpp
2122
matrix.cpp

source/module_base/global_function.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,12 @@ static inline void FREE_MUL_PTR(T_element* v, const T_N_first N_first, const T_N
318318
v = nullptr;
319319
}
320320

321+
double ddot_real(
322+
const int & dim,
323+
const std::complex<double>* psi_L,
324+
const std::complex<double>* psi_R,
325+
const bool reduce = true) ;
326+
321327
}//namespace GlobalFunc
322328
}//namespace ModuleBase
323329

source/module_base/global_function_ddotreal.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "global_function.h"
2+
#include "blas_connector.h"
3+
#include "src_parallel/parallel_reduce.h"
4+
5+
namespace ModuleBase
6+
{
7+
namespace GlobalFunc
8+
{
9+
double ddot_real
10+
(
11+
const int &dim,
12+
const std::complex<double>* psi_L,
13+
const std::complex<double>* psi_R,
14+
const bool reduce
15+
)
16+
{
17+
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
18+
//qianrui modify 2021-3-14
19+
//Note that ddot_(2*dim,a,1,b,1) = REAL( zdotc_(dim,a,1,b,1) )
20+
int dim2=2*dim;
21+
double *pL,*pR;
22+
pL=(double *)psi_L;
23+
pR=(double *)psi_R;
24+
double result=BlasConnector::dot(dim2,pL,1,pR,1);
25+
if(reduce) Parallel_Reduce::reduce_double_pool( result );
26+
return result;
27+
//======================================================================
28+
/*std::complex<double> result(0,0);
29+
for (int i=0;i<dim;i++)
30+
{
31+
result += conj( psi_L[i] ) * psi_R[i];
32+
}
33+
Parallel_Reduce::reduce_complex_double_pool( result );
34+
return result.real();*/
35+
//>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
36+
}
37+
}
38+
}

source/src_pw/diago_cg.cpp

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void Diago_CG::diag
7373

7474
this->hpw->h_1psi(dim , phi_m, hphi, sphi);
7575

76-
e[m] = this->ddot_real(dim, phi_m, hphi );
76+
e[m] = ModuleBase::GlobalFunc::ddot_real(dim, phi_m, hphi );
7777

7878
int iter = 0;
7979
double gg_last = 0.0;
@@ -185,9 +185,9 @@ void Diago_CG::calculate_gradient(
185185

186186
// Update lambda !
187187
// (4) <psi|SPH|psi >
188-
const double eh = this->ddot_real( dim, spsi, g);
188+
const double eh = ModuleBase::GlobalFunc::ddot_real( dim, spsi, g);
189189
// (5) <psi|SPS|psi >
190-
const double es = this->ddot_real( dim, spsi, ppsi);
190+
const double es = ModuleBase::GlobalFunc::ddot_real( dim, spsi, ppsi);
191191
const double lambda = eh / es;
192192

193193
// Update g!
@@ -274,7 +274,7 @@ void Diago_CG::calculate_gamma_cg(
274274
// (1) Update gg_inter!
275275
// gg_inter = <g|psg>
276276
// Attention : the 'g' in psg is getted last time
277-
gg_inter = this->ddot_real( dim, g, psg);// b means before
277+
gg_inter = ModuleBase::GlobalFunc::ddot_real( dim, g, psg);// b means before
278278
}
279279

280280
// (2) Update for psg!
@@ -289,7 +289,7 @@ void Diago_CG::calculate_gamma_cg(
289289

290290
// (3) Update gg_now!
291291
// gg_now = < g|P|sg > = < g|psg >
292-
const double gg_now = this->ddot_real( dim, g, psg);
292+
const double gg_now = ModuleBase::GlobalFunc::ddot_real( dim, g, psg);
293293

294294
if (iter==0)
295295
{
@@ -344,12 +344,12 @@ bool Diago_CG::update_psi(
344344
if (test_cg==1) ModuleBase::TITLE("Diago_CG","update_psi");
345345
//ModuleBase::timer::tick("Diago_CG","update");
346346
this->hpw->h_1psi(dim, cg, hcg, scg);
347-
cg_norm = sqrt( this->ddot_real(dim, cg, scg) );
347+
cg_norm = sqrt( ModuleBase::GlobalFunc::ddot_real(dim, cg, scg) );
348348

349349
if (cg_norm < 1.0e-10 ) return 1;
350350

351-
const double a0 = this->ddot_real(dim, psi_m, hcg) * 2.0 / cg_norm;
352-
const double b0 = this->ddot_real(dim, cg, hcg) / ( cg_norm * cg_norm ) ;
351+
const double a0 = ModuleBase::GlobalFunc::ddot_real(dim, psi_m, hcg) * 2.0 / cg_norm;
352+
const double b0 = ModuleBase::GlobalFunc::ddot_real(dim, cg, hcg) / ( cg_norm * cg_norm ) ;
353353

354354
const double e0 = eigenvalue;
355355

@@ -448,7 +448,7 @@ void Diago_CG::schmit_orth
448448
//qianrui replace 2021-3-15
449449
char trans2='N';
450450
zgemv_(&trans2,&dim,&m,&ModuleBase::NEG_ONE,psi.c,&dmx,lagrange,&inc,&ModuleBase::ONE,psi_m,&inc);
451-
psi_norm -= ddot_real(m,lagrange,lagrange,false);
451+
psi_norm -= ModuleBase::GlobalFunc::ddot_real(m,lagrange,lagrange,false);
452452
//======================================================================
453453
/*for (int j = 0; j < m; j++)
454454
{
@@ -484,33 +484,3 @@ void Diago_CG::schmit_orth
484484
//ModuleBase::timer::tick("Diago_CG","schmit_orth");
485485
return ;
486486
}
487-
488-
489-
double Diago_CG::ddot_real
490-
(
491-
const int &dim,
492-
const std::complex<double>* psi_L,
493-
const std::complex<double>* psi_R,
494-
const bool reduce
495-
)
496-
{
497-
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
498-
//qianrui modify 2021-3-14
499-
//Note that ddot_(2*dim,a,1,b,1) = REAL( zdotc_(dim,a,1,b,1) )
500-
int dim2=2*dim;
501-
double *pL,*pR;
502-
pL=(double *)psi_L;
503-
pR=(double *)psi_R;
504-
double result=BlasConnector::dot(dim2,pL,1,pR,1);
505-
if(reduce) Parallel_Reduce::reduce_double_pool( result );
506-
return result;
507-
//======================================================================
508-
/*std::complex<double> result(0,0);
509-
for (int i=0;i<dim;i++)
510-
{
511-
result += conj( psi_L[i] ) * psi_R[i];
512-
}
513-
Parallel_Reduce::reduce_complex_double_pool( result );
514-
return result.real();*/
515-
//>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
516-
}

source/src_pw/diago_cg.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@
55
#include "../module_base/global_variable.h"
66
#include "../module_base/complexmatrix.h"
77

8-
#include "src_pw/hamilt_pw.h"
8+
#if ((defined __CUDA) || (defined __ROCM))
9+
10+
#ifdef __CUDA
11+
#include "hamilt_pw.cuh"
12+
#else
13+
#include "hamilt_pw_hip.h"
14+
#endif
15+
16+
#else
17+
#include "hamilt_pw.h"
18+
#endif
919

1020
class Diago_CG
1121
{
@@ -16,12 +26,6 @@ class Diago_CG
1626

1727
static int moved;
1828

19-
static double ddot_real(
20-
const int & dim,
21-
const std::complex<double>* psi_L,
22-
const std::complex<double>* psi_R,
23-
const bool reduce = true) ;
24-
2529
void diag(
2630
ModuleBase::ComplexMatrix &phi,
2731
double *e,

source/src_pw/diago_david.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "diago_david.h"
2-
#include "diago_cg.h"
32
#include "../src_parallel/parallel_reduce.h"
43
#include "../module_base/timer.h"
54
#include "module_base/constants.h"
@@ -255,8 +254,8 @@ void Diago_David::cal_grad
255254
ppsi[ig] = respsi[ig] / precondition[ig] ;
256255
}
257256
/*
258-
double ppsi_norm = Diago_CG::ddot_real( npw, ppsi, ppsi);
259-
double rpsi_norm = Diago_CG::ddot_real( npw, respsi, respsi);
257+
double ppsi_norm = ModuleBase::GlobalFunc::ddot_real( npw, ppsi, ppsi);
258+
double rpsi_norm = ModuleBase::GlobalFunc::ddot_real( npw, respsi, respsi);
260259
assert( rpsi_norm > 0.0 );
261260
assert( ppsi_norm > 0.0 );
262261
*/
@@ -498,7 +497,7 @@ void Diago_David::cal_err
498497
}
499498
}
500499

501-
err[m] = Diago_CG::ddot_real( npw, respsi, respsi );
500+
err[m] = ModuleBase::GlobalFunc::ddot_real( npw, respsi, respsi );
502501
err[m] = sqrt( err[m] );
503502
}
504503

source/src_pw/diago_david.cu

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,7 @@ void Diago_David_CUDA::cal_grad
289289
// ppsi[ig] = respsi[ig] / precondition[ig] ;
290290
// }
291291
kernel_precondition_david<double, double2><<<block, thread>>>(ppsi, respsi, npw, precondition);
292-
/*
293-
double ppsi_norm = Diago_CG::ddot_real( npw, ppsi, ppsi);
294-
double rpsi_norm = Diago_CG::ddot_real( npw, respsi, respsi);
295-
assert( rpsi_norm > 0.0 );
296-
assert( ppsi_norm > 0.0 );
297-
*/
292+
298293
this->SchmitOrth(npw, nbase+notconv, nbase+m, basis, ppsi, spsi);
299294

300295
GlobalC::hm.hpw.h_1psi(npw, ppsi, hpsi, spsi);
@@ -519,7 +514,7 @@ void Diago_David_CUDA::cal_err
519514
}
520515
}
521516

522-
err[m] = Diago_CG::ddot_real( npw, respsi, respsi );
517+
err[m] = ModuleBase::GlobalFunc::ddot_real( npw, respsi, respsi );
523518
err[m] = sqrt( err[m] );
524519
}
525520

source/src_pw/diago_david.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
#include "../module_base/global_variable.h"
1515
#include "../module_base/complexmatrix.h"
1616

17-
#include "src_pw/hamilt_pw.h"
17+
#if ((defined __CUDA) || (defined __ROCM))
18+
19+
#ifdef __CUDA
20+
#include "hamilt_pw.cuh"
21+
#else
22+
#include "hamilt_pw_hip.h"
23+
#endif
24+
25+
#else
26+
#include "hamilt_pw.h"
27+
#endif
1828

1929
class Diago_David
2030
{

source/src_pw/hamilt.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "global.h"
22
#include "hamilt.h"
3-
// #include "diago_cg.h"
43
#include "diago_david.h"
54
#include "diago_cg.cuh"
65
#include "cufft.h"
@@ -285,7 +284,7 @@ void Hamilt::diagH_pw(
285284
}
286285
else if(GlobalV::KS_SOLVER=="dav")
287286
{
288-
Diago_David david;
287+
Diago_David david(&GlobalC::hm.hpw);
289288
if(GlobalV::NPOL==1)
290289
{
291290
david.diag(GlobalC::wf.evc[ik0], GlobalC::wf.ekb[ik], GlobalC::kv.ngk[ik],

source/src_pw/hamilt_hip.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "global.h"
22
#include "hamilt.h"
33
#include "hip/hip_runtime.h"
4-
// #include "diago_cg.h"
4+
//
55
#include "diago_cg_hip.h"
66
#include "diago_david.h"
77
#include "../module_base/timer.h"

0 commit comments

Comments
 (0)