Skip to content

Commit 60c1f77

Browse files
Refactor: Separate BLAS functions' declaration and implementation, and add a device_type flag for BLAS kernels (#5242)
* Separate blas kernels' declaration and definition * Fix compilation bug * Move cblas link part to header file * Remove inline keyword * Fix test compilation * Fix * Fix * Fix library sequence * Optimize link usage * Fix test building error * Fix limxc build failing * Add device_type flag * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent cfffc78 commit 60c1f77

File tree

13 files changed

+296
-151
lines changed

13 files changed

+296
-151
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ OBJS_MAIN=main.o\
121121

122122
OBJS_BASE=abfs-vector3_order.o\
123123
assoc_laguerre.o\
124+
blas_connector.o\
124125
complexarray.o\
125126
complexmatrix.o\
126127
clebsch_gordan_coeff.o\

source/module_base/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_library(
1111
base
1212
OBJECT
1313
assoc_laguerre.cpp
14+
blas_connector.cpp
1415
clebsch_gordan_coeff.cpp
1516
complexarray.cpp
1617
complexmatrix.cpp
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#include "blas_connector.h"

#include <stdexcept>
2+
3+
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
4+
{
5+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
6+
saxpy_(&n, &alpha, X, &incX, Y, &incY);
7+
}
8+
}
9+
10+
void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
11+
{
12+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
13+
daxpy_(&n, &alpha, X, &incX, Y, &incY);
14+
}
15+
}
16+
17+
void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
18+
{
19+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
20+
caxpy_(&n, &alpha, X, &incX, Y, &incY);
21+
}
22+
}
23+
24+
void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
25+
{
26+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
27+
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
28+
}
29+
}
30+
31+
32+
// x=a*x
33+
void BlasConnector::scal( const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type)
34+
{
35+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
36+
sscal_(&n, &alpha, X, &incX);
37+
}
38+
}
39+
40+
void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
41+
{
42+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
43+
dscal_(&n, &alpha, X, &incX);
44+
}
45+
}
46+
47+
void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
48+
{
49+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
50+
cscal_(&n, &alpha, X, &incX);
51+
}
52+
}
53+
54+
void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
55+
{
56+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
57+
zscal_(&n, &alpha, X, &incX);
58+
}
59+
}
60+
61+
62+
// d=x*y
63+
float BlasConnector::dot( const int n, const float *X, const int incX, const float *Y, const int incY, base_device::AbacusDevice_t device_type)
64+
{
65+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
66+
return sdot_(&n, X, &incX, Y, &incY);
67+
}
68+
}
69+
70+
double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
71+
{
72+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
73+
return ddot_(&n, X, &incX, Y, &incY);
74+
}
75+
}
76+
77+
// C = a * A.? * B.? + b * C
78+
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
79+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
80+
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
81+
{
82+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
83+
sgemm_(&transb, &transa, &n, &m, &k,
84+
&alpha, b, &ldb, a, &lda,
85+
&beta, c, &ldc);
86+
}
87+
}
88+
89+
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
90+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
91+
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
92+
{
93+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
94+
dgemm_(&transb, &transa, &n, &m, &k,
95+
&alpha, b, &ldb, a, &lda,
96+
&beta, c, &ldc);
97+
}
98+
}
99+
100+
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
101+
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
102+
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
103+
{
104+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
105+
cgemm_(&transb, &transa, &n, &m, &k,
106+
&alpha, b, &ldb, a, &lda,
107+
&beta, c, &ldc);
108+
}
109+
}
110+
111+
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
112+
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
113+
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
114+
{
115+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
116+
zgemm_(&transb, &transa, &n, &m, &k,
117+
&alpha, b, &ldb, a, &lda,
118+
&beta, c, &ldc);
119+
}
120+
}
121+
122+
void BlasConnector::gemv(const char trans, const int m, const int n,
123+
const double alpha, const double* A, const int lda, const double* X, const int incx,
124+
const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type)
125+
{
126+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
127+
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
128+
}
129+
}
130+
131+
void BlasConnector::gemv(const char trans, const int m, const int n,
132+
const std::complex<float> alpha, const std::complex<float> *A, const int lda, const std::complex<float> *X, const int incx,
133+
const std::complex<float> beta, std::complex<float> *Y, const int incy, base_device::AbacusDevice_t device_type)
134+
{
135+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
136+
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
137+
}
138+
}
139+
140+
void BlasConnector::gemv(const char trans, const int m, const int n,
141+
const std::complex<double> alpha, const std::complex<double> *A, const int lda, const std::complex<double> *X, const int incx,
142+
const std::complex<double> beta, std::complex<double> *Y, const int incy, base_device::AbacusDevice_t device_type)
143+
{
144+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
145+
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
146+
}
147+
}
148+
149+
150+
// out = ||x||_2
151+
float BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
152+
{
153+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
154+
return snrm2_( &n, X, &incX );
155+
}
156+
}
157+
158+
159+
double BlasConnector::nrm2( const int n, const double *X, const int incX, base_device::AbacusDevice_t device_type )
160+
{
161+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162+
return dnrm2_( &n, X, &incX );
163+
}
164+
}
165+
166+
167+
double BlasConnector::nrm2( const int n, const std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type )
168+
{
169+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
170+
return dznrm2_( &n, X, &incX );
171+
}
172+
}
173+
174+
// copies a into b
175+
void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type)
176+
{
177+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
178+
dcopy_(&n, a, &incx, b, &incy);
179+
}
180+
}
181+
182+
void BlasConnector::copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy, base_device::AbacusDevice_t device_type)
183+
{
184+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
185+
zcopy_(&n, a, &incx, b, &incy);
186+
}
187+
}

0 commit comments

Comments
 (0)