Skip to content

Commit be47750

Browse files
committed
Feature: output math lib info
On -DINFO defined in cmake, the program will output parameters to std::cerr stream of zgemm and zhegvx.
1 parent 0cdc725 commit be47750

File tree

7 files changed

+238
-88
lines changed

7 files changed

+238
-88
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ option(USE_OPENMP " Enable OpenMP in abacus." ON)
2020
option(ENABLE_ASAN "Enable AddressSanitizer" OFF)
2121
option(BUILD_TESTING "Build ABACUS unit tests" OFF)
2222
option(GENERATE_TEST_REPORTS "Enable test report generation" OFF)
23+
option(INFO "Enable gathering of math library information" OFF)
2324

2425
set(ABACUS_BIN_NAME abacus)
2526
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/modules)
@@ -261,6 +262,12 @@ add_compile_definitions(
261262
TEST_EXX_RADIAL=1
262263
)
263264

265+
if(INFO)
266+
message(STATUS "Will gather math lib info.")
267+
add_compile_definitions(GATHER_INFO)
268+
# modifications on blas_connector and lapack_connector
269+
endif()
270+
264271
IF (BUILD_TESTING)
265272
set(CMAKE_CXX_STANDARD 14) # Required in orbital
266273
include(CTest)

source/module_base/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_library(
77
export.cpp
88
integral.cpp
99
inverse_matrix.cpp
10+
gather_math_lib_info.cpp
1011
global_file.cpp
1112
global_function.cpp
1213
global_function_ddotreal.cpp
@@ -16,7 +17,7 @@ add_library(
1617
math_polyint.cpp
1718
math_sphbes.cpp
1819
math_ylmreal.cpp
19-
math_bspline.cpp
20+
math_bspline.cpp
2021
math_chebyshev.cpp
2122
mathzone.cpp
2223
mathzone_add1.cpp

source/module_base/blas_connector.h

Lines changed: 87 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ extern "C"
1818
void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY);
1919
void caxpy_(const int *N, const std::complex<float> *alpha, const std::complex<float> *X, const int *incX, std::complex<float> *Y, const int *incY);
2020
void zaxpy_(const int *N, const std::complex<double> *alpha, const std::complex<double> *X, const int *incX, std::complex<double> *Y, const int *incY);
21-
21+
2222
void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy);
23-
void zcopy_(long const *n, const std::complex<double> *a, int const *incx, std::complex<double> *b, int const *incy);
23+
void zcopy_(long const *n, const std::complex<double> *a, int const *incx, std::complex<double> *b, int const *incy);
2424

2525
//reason for passing results as argument instead of returning it:
2626
//see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/
27-
void zdotc_(std::complex<double> *result, const int *n, const std::complex<double> *zx,
27+
void zdotc_(std::complex<double> *result, const int *n, const std::complex<double> *zx,
2828
const int *incx, const std::complex<double> *zy, const int *incy);
2929
// Peize Lin add ?dot 2017-10-27, to compute d=x*y
3030
float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY);
@@ -36,36 +36,36 @@ extern "C"
3636
double dznrm2_( const int *n, const std::complex<double> *X, const int *incX );
3737

3838
// level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work.
39-
void dgemv_(const char *transa, const int *m, const int *n, const double *alpha, const double *a,
39+
void dgemv_(const char *transa, const int *m, const int *n, const double *alpha, const double *a,
4040
const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy);
41-
41+
4242
void zgemv_(const char *trans, const int *m, const int *n, const std::complex<double> *alpha,
4343
const std::complex<double> *a, const int *lda, const std::complex<double> *x, const int *incx,
4444
const std::complex<double> *beta, std::complex<double> *y, const int *incy);
4545

46-
void dsymv_(const char *uplo, const int *n,
47-
const double *alpha, const double *a, const int *lda,
48-
const double *x, const int *incx,
49-
const double *beta, double *y, const int *incy);
46+
void dsymv_(const char *uplo, const int *n,
47+
const double *alpha, const double *a, const int *lda,
48+
const double *x, const int *incx,
49+
const double *beta, double *y, const int *incy);
5050

5151
// A := alpha x * y.T + A
5252
void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda);
5353
void zgerc_(int *m, int *n, std::complex<double> *alpha,std::complex<double> *x, int *incx, std::complex<double> *y, int *incy,std::complex<double> *a, int *lda);
5454

5555
// level 3: matrix-matrix operations, O(n^2) data and O(n^3) work.
56-
56+
5757
// Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? + b * C
5858
// A is general
5959
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
60-
const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
60+
const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
6161
const float *beta, float *c, const int *ldc);
6262
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
63-
const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
63+
const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
6464
const double *beta, double *c, const int *ldc);
6565
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
66-
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
66+
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
6767
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);
68-
68+
6969
//a is symmetric
7070
void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
7171
const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
@@ -91,50 +91,50 @@ class BlasConnector
9191

9292
// Peize Lin add 2016-08-04
9393
// y=a*x+y
94-
static inline
94+
static inline
9595
void axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY)
9696
{
9797
saxpy_(&n, &alpha, X, &incX, Y, &incY);
98-
}
99-
static inline
98+
}
99+
static inline
100100
void axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY)
101101
{
102102
daxpy_(&n, &alpha, X, &incX, Y, &incY);
103-
}
104-
static inline
103+
}
104+
static inline
105105
void axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY)
106106
{
107107
caxpy_(&n, &alpha, X, &incX, Y, &incY);
108-
}
109-
static inline
108+
}
109+
static inline
110110
void axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY)
111111
{
112112
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
113-
}
114-
113+
}
114+
115115
// Peize Lin add 2016-08-04
116116
// x=a*x
117-
static inline
117+
static inline
118118
void scal( const int n, const float alpha, float *X, const int incX)
119119
{
120120
sscal_(&n, &alpha, X, &incX);
121-
}
122-
static inline
121+
}
122+
static inline
123123
void scal( const int n, const double alpha, double *X, const int incX)
124124
{
125125
dscal_(&n, &alpha, X, &incX);
126-
}
127-
static inline
126+
}
127+
static inline
128128
void scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX)
129129
{
130130
cscal_(&n, &alpha, X, &incX);
131-
}
132-
static inline
131+
}
132+
static inline
133133
void scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX)
134134
{
135135
zscal_(&n, &alpha, X, &incX);
136-
}
137-
136+
}
137+
138138
// Peize Lin add 2017-10-27
139139
// d=x*y
140140
static inline
@@ -149,32 +149,32 @@ class BlasConnector
149149
}
150150

151151
// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
152-
// C = a * A.? * B.? + b * C
152+
// C = a * A.? * B.? + b * C
153153
static inline
154154
void gemm(const char transa, const char transb, const int m, const int n, const int k,
155-
const float alpha, const float *a, const int lda, const float *b, const int ldb,
155+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
156156
const float beta, float *c, const int ldc)
157157
{
158158
sgemm_(&transb, &transa, &n, &m, &k,
159-
&alpha, b, &ldb, a, &lda,
159+
&alpha, b, &ldb, a, &lda,
160160
&beta, c, &ldc);
161161
}
162162
static inline
163163
void gemm(const char transa, const char transb, const int m, const int n, const int k,
164-
const double alpha, const double *a, const int lda, const double *b, const int ldb,
164+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
165165
const double beta, double *c, const int ldc)
166166
{
167167
dgemm_(&transb, &transa, &n, &m, &k,
168-
&alpha, b, &ldb, a, &lda,
168+
&alpha, b, &ldb, a, &lda,
169169
&beta, c, &ldc);
170170
}
171171
static inline
172172
void gemm(const char transa, const char transb, const int m, const int n, const int k,
173-
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
173+
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
174174
const std::complex<double> beta, std::complex<double> *c, const int ldc)
175175
{
176176
zgemm_(&transb, &transa, &n, &m, &k,
177-
&alpha, b, &ldb, a, &lda,
177+
&alpha, b, &ldb, a, &lda,
178178
&beta, c, &ldc);
179179
}
180180

@@ -196,7 +196,7 @@ class BlasConnector
196196
return dznrm2_( &n, X, &incX );
197197
}
198198

199-
// copies a into b
199+
// copies a into b
200200
static inline
201201
void copy(const long n, const double *a, const int incx, double *b, const int incy)
202202
{
@@ -206,8 +206,52 @@ class BlasConnector
206206
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy)
207207
{
208208
zcopy_(&n, a, &incx, b, &incy);
209-
}
210-
209+
}
211210
};
212211

213-
#endif
212+
// If GATHER_INFO is defined, the original function is replaced with a "i" suffix,
213+
// preventing changes on the original code.
214+
// The real function call is at gather_math_lib_info.cpp
215+
#ifdef GATHER_INFO
216+
217+
#define zgemm_ zgemm_i
218+
void zgemm_i(const char *transa,
219+
const char *transb,
220+
const int *m,
221+
const int *n,
222+
const int *k,
223+
const std::complex<double> *alpha,
224+
const std::complex<double> *a,
225+
const int *lda,
226+
const std::complex<double> *b,
227+
const int *ldb,
228+
const std::complex<double> *beta,
229+
std::complex<double> *c,
230+
const int *ldc);
231+
232+
#define zaxpy_ zaxpy_i
233+
void zaxpy_i(const int *N,
234+
const std::complex<double> *alpha,
235+
const std::complex<double> *X,
236+
const int *incX,
237+
std::complex<double> *Y,
238+
const int *incY);
239+
240+
/*
241+
#define zgemv_ zgemv_i
242+
243+
void zgemv_i(const char *trans,
244+
const int *m,
245+
const int *n,
246+
const std::complex<double> *alpha,
247+
const std::complex<double> *a,
248+
const int *lda,
249+
const std::complex<double> *x,
250+
const int *incx,
251+
const std::complex<double> *beta,
252+
std::complex<double> *y,
253+
const int *incy);
254+
*/
255+
256+
#endif // GATHER_INFO
257+
#endif // BLAS_CONNECTOR_H
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// This file defines the math lib wrapper for output information before executing computations.
2+
#undef GATHER_INFO
3+
#include "module_base/blas_connector.h"
4+
#include "module_base/lapack_connector.h"
5+
6+
#include <iostream>
7+
8+
void zgemm_i(const char *transa,
9+
const char *transb,
10+
const int *m,
11+
const int *n,
12+
const int *k,
13+
const std::complex<double> *alpha,
14+
const std::complex<double> *a,
15+
const int *lda,
16+
const std::complex<double> *b,
17+
const int *ldb,
18+
const std::complex<double> *beta,
19+
std::complex<double> *c,
20+
const int *ldc)
21+
{
22+
std::cerr << std::defaultfloat << "zgemm " << *transa << " " << *transb << " " << *m << " " << *n << " " << *k
23+
<< " " << *alpha << " " << *lda << " " << *ldb << " " << *beta << " " << *ldc << std::endl;
24+
zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
25+
}
26+
27+
void zaxpy_i(const int *N,
28+
const std::complex<double> *alpha,
29+
const std::complex<double> *X,
30+
const int *incX,
31+
std::complex<double> *Y,
32+
const int *incY)
33+
{
34+
// std::cout << "zaxpy " << *N << std::endl;
35+
// alpha is a coefficient
36+
// incX, incY is always 1
37+
zaxpy_(N, alpha, X, incX, Y, incY);
38+
}
39+
40+
void zhegvx_i(const int *itype,
41+
const char *jobz,
42+
const char *range,
43+
const char *uplo,
44+
const int *n,
45+
std::complex<double> *a,
46+
const int *lda,
47+
std::complex<double> *b,
48+
const int *ldb,
49+
const double *vl,
50+
const double *vu,
51+
const int *il,
52+
const int *iu,
53+
const double *abstol,
54+
const int *m,
55+
double *w,
56+
std::complex<double> *z,
57+
const int *ldz,
58+
std::complex<double> *work,
59+
const int *lwork,
60+
double *rwork,
61+
int *iwork,
62+
int *ifail,
63+
int *info)
64+
{
65+
std::cerr << std::defaultfloat << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo << " " << *n
66+
<< " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu << " " << *abstol
67+
<< " " << *m << " " << *lwork << " " << *info << std::endl;
68+
zhegvx_(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork,
69+
iwork, ifail, info);
70+
}

0 commit comments

Comments
 (0)