Skip to content

Commit c93cd1a

Browse files
authored
[Refactor] Remove my_math.hpp (#6002)
* Remove my_math.hpp * Remove the file
1 parent 0d9f762 commit c93cd1a

File tree

6 files changed

+246
-423
lines changed

6 files changed

+246
-423
lines changed

source/module_base/scalapack_connector.h

Lines changed: 202 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ extern "C"
1212
int *desc,
1313
const int *m, const int *n, const int *mb, const int *nb, const int *irsrc, const int *icsrc,
1414
const int *ictxt, const int *lld, int *info);
15+
16+
void pddot_(int* n, double* dot, double* x, int* ix, int* jx, int* descx, int* incx,
17+
double* y, int* iy, int* jy, int* descy, int* incy);
18+
void pzdotc_(int* n, std::complex<double>* dot, std::complex<double>* x, int* ix, int* jx, int* descx, int* incx,
19+
std::complex<double>* y, int* iy, int* jy, int* descy, int* incy);
1520

1621
void pdpotrf_(char *uplo, int *n, double *a, int *ia, int *ja, int *desca, int *info);
1722
// void pzpotrf_(char *uplo, int *n, double _Complex *a, int *ia, int *ja, int *desca, int *info);
@@ -69,7 +74,10 @@ extern "C"
6974
void pztrmm_(char *side , char *uplo , char *transa , char *diag , int *m , int *n ,
7075
std::complex<double> *alpha , std::complex<double> *a , int *ia , int *ja , int *desca ,
7176
std::complex<double> *b , int *ib , int *jb , int *descb );
72-
77+
void pzhemm_(char* side , char* uplo , int* m , int* n ,
78+
std::complex<double>* alpha , std::complex<double>* a , int* ia , int* ja , int* desca ,
79+
std::complex<double>* b , int* ib , int* jb , int* descb ,
80+
std::complex<double>* beta , std::complex<double>* c , int* ic , int* jc , int* descc );
7381
void pzgetrf_(
7482
const int *M, const int *N,
7583
std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
@@ -200,6 +208,38 @@ class ScalapackConnector
200208
pzgeadd_(&transa, &m, &n, &alpha, a, &ia, &ja, desca, &beta, c, &ic, &jc, descc);
201209
}
202210

211+
static inline
212+
void dot(int n,
213+
double& dot,
214+
double* a,
215+
int ia,
216+
int ja,
217+
int inca,
218+
double* b,
219+
int ib,
220+
int jb,
221+
int incb,
222+
int* desc)
223+
{
224+
pddot_(&n, &dot, a, &ia, &ja, desc, &inca, b, &ib, &jb, desc, &incb);
225+
}
226+
227+
static inline
228+
void dot(int n,
229+
std::complex<double>& dotc,
230+
std::complex<double>* a,
231+
int ia,
232+
int ja,
233+
int inca,
234+
std::complex<double>* b,
235+
int ib,
236+
int jb,
237+
int incb,
238+
int* desc)
239+
{
240+
pzdotc_(&n, &dotc, a, &ia, &ja, desc, &inca, b, &ib, &jb, desc, &incb);
241+
}
242+
203243
static inline
204244
void gemm(
205245
const char transa, const char transb,
@@ -228,6 +268,85 @@ class ScalapackConnector
228268
B, &IB, &JB, DESCB, &beta, C, &IC, &JC, DESCC);
229269
}
230270

271+
static inline
272+
void gemm(char transa, char transb, int M, int N, int K,
273+
double alpha,
274+
double* A,
275+
double* B,
276+
double beta,
277+
double* C,
278+
int* DESC)
279+
{
280+
int isrc = 1;
281+
pdgemm_(&transa,
282+
&transb,
283+
&M,
284+
&N,
285+
&K,
286+
&alpha,
287+
A,
288+
&isrc,
289+
&isrc,
290+
DESC,
291+
B,
292+
&isrc,
293+
&isrc,
294+
DESC,
295+
&beta,
296+
C,
297+
&isrc,
298+
&isrc,
299+
DESC);
300+
}
301+
302+
static inline
303+
void gemm(char transa, char transb, int M, int N, int K,
304+
std::complex<double> alpha,
305+
std::complex<double>* A,
306+
std::complex<double>* B,
307+
std::complex<double> beta,
308+
std::complex<double>* C,
309+
int* DESC)
310+
{
311+
312+
int isrc = 1;
313+
pzgemm_(&transa,
314+
&transb,
315+
&M,
316+
&N,
317+
&K,
318+
&alpha,
319+
A,
320+
&isrc,
321+
&isrc,
322+
DESC,
323+
B,
324+
&isrc,
325+
&isrc,
326+
DESC,
327+
&beta,
328+
C,
329+
&isrc,
330+
&isrc,
331+
DESC);
332+
}
333+
334+
static inline
335+
void symm(char side,
336+
char uplo,
337+
int m,
338+
int n,
339+
double alpha,
340+
double* a,
341+
double* b,
342+
double beta,
343+
double* c,
344+
int* desc)
345+
{
346+
int isrc = 1;
347+
pdsymm_(&side, &uplo, &m, &n, &alpha, a, &isrc, &isrc, desc, b, &isrc, &isrc, desc, &beta, c, &isrc, &isrc, desc);
348+
}
349+
231350
static inline
232351
void getrf(
233352
const int M, const int N,
@@ -263,6 +382,88 @@ class ScalapackConnector
263382
{
264383
pztranu_(&m, &n, &alpha, a, &ia, &ja, desca, &beta, c, &ic, &jc, descc);
265384
}
385+
386+
static inline
387+
int potrf(char uplo, int na, double* U, int* desc)
388+
{
389+
int isrc = 1;
390+
int info;
391+
pdpotrf_(&uplo, &na, U, &isrc, &isrc, desc, &info);
392+
return info;
393+
}
394+
395+
static inline
396+
int potrf(char uplo, int na, std::complex<double>* U, int* desc)
397+
{
398+
int isrc = 1;
399+
int info;
400+
pzpotrf_(&uplo, &na, U, &isrc, &isrc, desc, &info);
401+
return info;
402+
}
403+
404+
static inline
405+
void trmm(char side,
406+
char uplo,
407+
char trans,
408+
char diag,
409+
int m,
410+
int n,
411+
double alpha,
412+
double* a,
413+
double* b,
414+
int* desc)
415+
{
416+
int isrc = 1;
417+
pdtrmm_(&side, &uplo, &trans, &diag, &m, &n, &alpha, a, &isrc, &isrc, desc, b, &isrc, &isrc, desc);
418+
}
419+
420+
static inline
421+
void trmm(char side,
422+
char uplo,
423+
char trans,
424+
char diag,
425+
int m,
426+
int n,
427+
std::complex<double> alpha,
428+
std::complex<double>* a,
429+
std::complex<double>* b,
430+
int* desc)
431+
{
432+
int isrc = 1;
433+
pztrmm_(&side, &uplo, &trans, &diag, &m, &n, &alpha, a, &isrc, &isrc, desc, b, &isrc, &isrc, desc);
434+
}
435+
436+
static inline
437+
void hemm(char side,
438+
char uplo,
439+
int na,
440+
std::complex<double> alpha,
441+
std::complex<double>* a,
442+
std::complex<double>* b,
443+
std::complex<double> beta,
444+
std::complex<double>* c,
445+
int* desc)
446+
{
447+
int isrc = 1;
448+
pzhemm_(&side,
449+
&uplo,
450+
&na,
451+
&na,
452+
&alpha,
453+
a,
454+
&isrc,
455+
&isrc,
456+
desc,
457+
b,
458+
&isrc,
459+
&isrc,
460+
desc,
461+
&beta,
462+
c,
463+
&isrc,
464+
&isrc,
465+
desc);
466+
}
266467
};
267468

268469
#endif // __MPI

source/module_hsolver/genelpa/elpa_new.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#include "elpa_new.h"
22

33
#include "elpa_solver.h"
4-
#include "my_math.hpp"
4+
extern "C"
5+
{
6+
#include "Cblacs.h"
7+
}
58
#include "utils.h"
69
#include <cfloat>
710
#include <complex>

source/module_hsolver/genelpa/elpa_new_complex.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "elpa_new.h"
1010
#include "elpa_solver.h"
1111

12-
#include "my_math.hpp"
12+
#include "module_base/scalapack_connector.h"
1313
#include "utils.h"
1414

1515
extern std::map<int, elpa_t> NEW_ELPA_HANDLE_POOL;
@@ -72,7 +72,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
7272
t=-1;
7373
timer(myid, "A*U^-1", "2.1a", t);
7474
}
75-
Cpzgemm('C', 'N', nFull, 1.0, A, B, 0.0, zwork.data(), desc);
75+
ScalapackConnector::gemm('C', 'N', nFull, nFull, nFull, 1.0, A, B, 0.0, zwork.data(), desc);
7676
if(loglevel>1)
7777
{
7878
timer(myid, "A*U^-1", "2.1a", t);
@@ -84,7 +84,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
8484
t=-1;
8585
timer(myid, "U^-T*(A*U^-1)", "2.2a", t);
8686
}
87-
Cpzgemm('C', 'N', nFull, 1.0, B, zwork.data(), 0.0, A, desc);
87+
ScalapackConnector::gemm('C', 'N', nFull, nFull, nFull, 1.0, B, zwork.data(), 0.0, A, desc);
8888
if(loglevel>1)
8989
{
9090
timer(myid, "U^-T*(A*U^-1)", "2.2a", t);
@@ -98,7 +98,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
9898
t=-1;
9999
timer(myid, "B*A^T", "2.1b", t);
100100
}
101-
Cpzgemm('N', 'C', nFull, 1.0, B, A, 0.0, zwork.data(), desc);
101+
ScalapackConnector::gemm('N', 'C', nFull, nFull, nFull, 1.0, B, A, 0.0, zwork.data(), desc);
102102
if(loglevel>1)
103103
{
104104
timer(myid, "B*A^T", "2.1b", t);
@@ -109,7 +109,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
109109
t=-1;
110110
timer(myid, "B*(B*A^T)^T", "2.2b", t);
111111
}
112-
Cpzgemm('N', 'C', nFull, 1.0, B, zwork.data(), 0.0, A, desc);
112+
ScalapackConnector::gemm('N', 'C', nFull, nFull, nFull, 1.0, B, zwork.data(), 0.0, A, desc);
113113
if(loglevel>1)
114114
{
115115
timer(myid, "B*(B*A^T)^T", "2.2b", t);
@@ -168,7 +168,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
168168
t=-1;
169169
timer(myid, "pzpotrf_", "1", t);
170170
}
171-
Cpzpotrf('U', nFull, B, desc);
171+
ScalapackConnector::potrf('U', nFull, B, desc);
172172
if(loglevel>1)
173173
{
174174
timer(myid, "pzpotrf_", "1", t);
@@ -214,7 +214,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
214214
t=-1;
215215
timer(myid, "pzpotrf_", "2", t);
216216
}
217-
Cpzpotrf('U', nFull, B, desc);
217+
ScalapackConnector::potrf('U', nFull, B, desc);
218218
if(loglevel>1)
219219
{
220220
timer(myid, "pzpotrf_", "2", t);
@@ -290,7 +290,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
290290
t=-1;
291291
timer(myid, "qevq=qev*q^T", "2", t);
292292
}
293-
Cpzgemm('N', 'C', nFull, 1.0, zwork.data(), EigenVector, 0.0, B, desc);
293+
ScalapackConnector::gemm('N', 'C', nFull, nFull, nFull, 1.0, zwork.data(), EigenVector, 0.0, B, desc);
294294
if(loglevel>1)
295295
{
296296
timer(myid, "qevq=qev*q^T", "2", t);
@@ -310,7 +310,7 @@ int ELPA_Solver::composeEigenVector(int DecomposedState, std::complex<double>* B
310310
t=-1;
311311
timer(myid, "Cpztrmm", "1", t);
312312
}
313-
Cpztrmm('L', 'U', 'N', 'N', nFull, nev, 1.0, B, EigenVector, desc);
313+
ScalapackConnector::trmm('L', 'U', 'N', 'N', nFull, nev, 1.0, B, EigenVector, desc);
314314
if(loglevel>1)
315315
{
316316
timer(myid, "Cpztrmm", "1", t);
@@ -322,7 +322,7 @@ int ELPA_Solver::composeEigenVector(int DecomposedState, std::complex<double>* B
322322
t=-1;
323323
timer(myid, "Cpzgemm", "1", t);
324324
}
325-
Cpzgemm('C', 'N', nFull, nev, nFull, 1.0, B, zwork.data(), 0.0, EigenVector, desc);
325+
ScalapackConnector::gemm('C', 'N', nFull, nev, nFull, 1.0, B, zwork.data(), 0.0, EigenVector, desc);
326326
if(loglevel>1)
327327
{
328328
timer(myid, "Cpzgemm", "1", t);
@@ -368,19 +368,19 @@ void ELPA_Solver::verify(std::complex<double>* A, double* EigenValue, std::compl
368368
}
369369

370370
// R=V*D
371-
Cpzhemm('R', 'U', nFull, 1.0, D, V, 0.0, R, desc);
371+
ScalapackConnector::hemm('R', 'U', nFull, 1.0, D, V, 0.0, R, desc);
372372
if(loglevel>2) saveMatrix("VD.dat", nFull, R, desc, cblacs_ctxt);
373373
// R=A*V-V*D=A*V-R
374-
Cpzhemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
374+
ScalapackConnector::hemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
375375
if(loglevel>2) saveMatrix("AV-VD.dat", nFull, R, desc, cblacs_ctxt);
376376
// calculate the maximum and mean value of sum_i{R(:,i)*R(:,i)}
377377
double sumError=0;
378378
maxError=0;
379379
for(int i=1; i<=nev; ++i)
380380
{
381381
std::complex<double> E;
382-
Cpzdotc(nFull, E, R, 1, i, 1,
383-
R, 1, i, 1, desc);
382+
ScalapackConnector::dot(nFull, E, R, 1, i, 1,
383+
R, 1, i, 1, desc);
384384
double abs_E=std::abs(E);
385385
sumError+=abs_E;
386386
maxError=std::max(maxError, abs_E);
@@ -427,22 +427,22 @@ void ELPA_Solver::verify(std::complex<double>* A, std::complex<double>* B,
427427
}
428428

429429
// zwork=B*V
430-
Cpzhemm('L', 'U', nFull, 1.0, B, V, 0.0, zwork.data(), desc);
430+
ScalapackConnector::hemm('L', 'U', nFull, 1.0, B, V, 0.0, zwork.data(), desc);
431431
if(loglevel>2) saveMatrix("BV.dat", nFull, zwork.data(), desc, cblacs_ctxt);
432432
// R=B*V*D=zwork*D
433-
Cpzhemm('R', 'U', nFull, 1.0, D, zwork.data(), 0.0, R, desc);
433+
ScalapackConnector::hemm('L', 'U', nFull, 1.0, B, V, 0.0, zwork.data(), desc);
434434
if(loglevel>2) saveMatrix("BVD.dat", nFull, R, desc, cblacs_ctxt);
435435
// R=A*V-B*V*D=A*V-R
436-
Cpzhemm('L', 'U', nFull, 1.0, A, V, -1.0, R, desc);
436+
ScalapackConnector::hemm('L', 'U', nFull, 1.0, B, V, 0.0, zwork.data(), desc);
437437
if(loglevel>2) saveMatrix("AV-BVD.dat", nFull, R, desc, cblacs_ctxt);
438438
// calculate the maximum and mean value of sum_i{R(:,i)*R(:,i)}
439439
double sumError=0;
440440
maxError=0;
441441
for(int i=1; i<=nev; ++i)
442442
{
443443
std::complex<double> E;
444-
Cpzdotc(nFull, E, R, 1, i, 1,
445-
R, 1, i, 1, desc);
444+
ScalapackConnector::dot(nFull, E, R, 1, i, 1,
445+
R, 1, i, 1, desc);
446446
double abs_E=std::abs(E);
447447
sumError+=abs_E;
448448
maxError=std::max(maxError, abs_E);

0 commit comments

Comments
 (0)