Skip to content

Commit 22d12b3

Browse files
committed
Remove uncapsulted scalapack functions in dftu_occup.cpp
1 parent 06cffb7 commit 22d12b3

File tree

1 file changed

+47
-29
lines changed

1 file changed

+47
-29
lines changed

source/module_hamilt_lcao/module_dftu/dftu_occup.cpp

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,7 @@
55
#ifdef __LCAO
66
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
77
#endif
8-
9-
extern "C"
10-
{
11-
//I'm not sure what's happenig here, but the interface in scalapack_connecter.h
12-
//does not seem to work, so I'll use this one here
13-
void pzgemm_(
14-
const char *transa, const char *transb,
15-
const int *M, const int *N, const int *K,
16-
const std::complex<double> *alpha,
17-
const std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
18-
const std::complex<double> *B, const int *IB, const int *JB, const int *DESCB,
19-
const std::complex<double> *beta,
20-
std::complex<double> *C, const int *IC, const int *JC, const int *DESCC);
21-
22-
void pdgemm_(
23-
const char *transa, const char *transb,
24-
const int *M, const int *N, const int *K,
25-
const double *alpha,
26-
const double *A, const int *IA, const int *JA, const int *DESCA,
27-
const double *B, const int *IB, const int *JB, const int *DESCB,
28-
const double *beta,
29-
double *C, const int *IC, const int *JC, const int *DESCC);
30-
}
8+
#include "module_base/scalapack_connector.h"
319

3210
namespace ModuleDFTU
3311
{
@@ -161,7 +139,7 @@ void DFTU::cal_occup_m_k(const int iter,
161139

162140
//=================Part 1======================
163141
// call SCALAPACK routine to calculate the product of the S and density matrix
164-
const char transN = 'N', transT = 'T';
142+
char transN = 'N', transT = 'T';
165143
const int one_int = 1;
166144
const std::complex<double> beta(0.0,0.0), alpha(1.0,0.0);
167145

@@ -182,7 +160,27 @@ void DFTU::cal_occup_m_k(const int iter,
182160
}
183161

184162
#ifdef __MPI
185-
pzgemm_(&transN,
163+
ScalapackConnector::gemm(transN,
164+
transT,
165+
PARAM.globalv.nlocal,
166+
PARAM.globalv.nlocal,
167+
PARAM.globalv.nlocal,
168+
alpha,
169+
s_k_pointer,
170+
one_int,
171+
one_int,
172+
&this->paraV->desc[0],
173+
dm_k[ik].data(),
174+
//dm_k[ik].c,
175+
one_int,
176+
one_int,
177+
&this->paraV->desc[0],
178+
beta,
179+
srho.data(),
180+
one_int,
181+
one_int,
182+
&this->paraV->desc[0]);
183+
/*pzgemm_(&transN,
186184
&transT,
187185
&PARAM.globalv.nlocal,
188186
&PARAM.globalv.nlocal,
@@ -201,7 +199,7 @@ void DFTU::cal_occup_m_k(const int iter,
201199
&srho[0],
202200
&one_int,
203201
&one_int,
204-
this->paraV->desc);
202+
this->paraV->desc);*/
205203
#endif
206204

207205
const int spin = kv.isk[ik];
@@ -382,7 +380,7 @@ void DFTU::cal_occup_m_gamma(const int iter,
382380

383381
//=================Part 1======================
384382
// call PBLAS routine to calculate the product of the S and density matrix
385-
const char transN = 'N', transT = 'T';
383+
char transN = 'N', transT = 'T';
386384
const int one_int = 1;
387385
const double alpha = 1.0, beta = 0.0;
388386

@@ -393,7 +391,27 @@ void DFTU::cal_occup_m_gamma(const int iter,
393391
double* s_gamma_pointer = dynamic_cast<hamilt::HamiltLCAO<double, double>*>(p_ham)->getSk();
394392

395393
#ifdef __MPI
396-
pdgemm_(&transN,
394+
ScalapackConnector::gemm(transN,
395+
transT,
396+
PARAM.globalv.nlocal,
397+
PARAM.globalv.nlocal,
398+
PARAM.globalv.nlocal,
399+
alpha,
400+
s_gamma_pointer,
401+
one_int,
402+
one_int,
403+
&this->paraV->desc[0],
404+
dm_gamma[is].data(),
405+
//dm_gamma[is].c,
406+
one_int,
407+
one_int,
408+
&this->paraV->desc[0],
409+
beta,
410+
srho.data(),
411+
one_int,
412+
one_int,
413+
&this->paraV->desc[0]);
414+
/*pdgemm_(&transN,
397415
&transT,
398416
&PARAM.globalv.nlocal,
399417
&PARAM.globalv.nlocal,
@@ -412,7 +430,7 @@ void DFTU::cal_occup_m_gamma(const int iter,
412430
&srho[0],
413431
&one_int,
414432
&one_int,
415-
this->paraV->desc);
433+
this->paraV->desc);*/
416434
#endif
417435

418436
for (int it = 0; it < ucell.ntype; it++)

0 commit comments

Comments
 (0)