1+ #include " blas_connector.h"
2+
3+ void BlasConnector::axpy ( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
4+ {
5+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
6+ saxpy_ (&n, &alpha, X, &incX, Y, &incY);
7+ }
8+ }
9+
10+ void BlasConnector::axpy ( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
11+ {
12+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
13+ daxpy_ (&n, &alpha, X, &incX, Y, &incY);
14+ }
15+ }
16+
17+ void BlasConnector::axpy ( const int n, const std::complex <float > alpha, const std::complex <float > *X, const int incX, std::complex <float > *Y, const int incY, base_device::AbacusDevice_t device_type)
18+ {
19+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
20+ caxpy_ (&n, &alpha, X, &incX, Y, &incY);
21+ }
22+ }
23+
24+ void BlasConnector::axpy ( const int n, const std::complex <double > alpha, const std::complex <double > *X, const int incX, std::complex <double > *Y, const int incY, base_device::AbacusDevice_t device_type)
25+ {
26+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
27+ zaxpy_ (&n, &alpha, X, &incX, Y, &incY);
28+ }
29+ }
30+
31+
32+ // x=a*x
33+ void BlasConnector::scal ( const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type)
34+ {
35+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
36+ sscal_ (&n, &alpha, X, &incX);
37+ }
38+ }
39+
40+ void BlasConnector::scal ( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
41+ {
42+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
43+ dscal_ (&n, &alpha, X, &incX);
44+ }
45+ }
46+
47+ void BlasConnector::scal ( const int n, const std::complex <float > alpha, std::complex <float > *X, const int incX, base_device::AbacusDevice_t device_type)
48+ {
49+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
50+ cscal_ (&n, &alpha, X, &incX);
51+ }
52+ }
53+
54+ void BlasConnector::scal ( const int n, const std::complex <double > alpha, std::complex <double > *X, const int incX, base_device::AbacusDevice_t device_type)
55+ {
56+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
57+ zscal_ (&n, &alpha, X, &incX);
58+ }
59+ }
60+
61+
62+ // d=x*y
63+ float BlasConnector::dot ( const int n, const float *X, const int incX, const float *Y, const int incY, base_device::AbacusDevice_t device_type)
64+ {
65+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
66+ return sdot_ (&n, X, &incX, Y, &incY);
67+ }
68+ }
69+
70+ double BlasConnector::dot ( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
71+ {
72+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
73+ return ddot_ (&n, X, &incX, Y, &incY);
74+ }
75+ }
76+
77+ // C = a * A.? * B.? + b * C
78+ void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
79+ const float alpha, const float *a, const int lda, const float *b, const int ldb,
80+ const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
81+ {
82+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
83+ sgemm_ (&transb, &transa, &n, &m, &k,
84+ &alpha, b, &ldb, a, &lda,
85+ &beta, c, &ldc);
86+ }
87+ }
88+
89+ void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
90+ const double alpha, const double *a, const int lda, const double *b, const int ldb,
91+ const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
92+ {
93+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
94+ dgemm_ (&transb, &transa, &n, &m, &k,
95+ &alpha, b, &ldb, a, &lda,
96+ &beta, c, &ldc);
97+ }
98+ }
99+
100+ void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
101+ const std::complex <float > alpha, const std::complex <float > *a, const int lda, const std::complex <float > *b, const int ldb,
102+ const std::complex <float > beta, std::complex <float > *c, const int ldc, base_device::AbacusDevice_t device_type)
103+ {
104+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
105+ cgemm_ (&transb, &transa, &n, &m, &k,
106+ &alpha, b, &ldb, a, &lda,
107+ &beta, c, &ldc);
108+ }
109+ }
110+
111+ void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
112+ const std::complex <double > alpha, const std::complex <double > *a, const int lda, const std::complex <double > *b, const int ldb,
113+ const std::complex <double > beta, std::complex <double > *c, const int ldc, base_device::AbacusDevice_t device_type)
114+ {
115+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
116+ zgemm_ (&transb, &transa, &n, &m, &k,
117+ &alpha, b, &ldb, a, &lda,
118+ &beta, c, &ldc);
119+ }
120+ }
121+
122+ void BlasConnector::gemv (const char trans, const int m, const int n,
123+ const double alpha, const double * A, const int lda, const double * X, const int incx,
124+ const double beta, double * Y, const int incy, base_device::AbacusDevice_t device_type)
125+ {
126+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
127+ dgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
128+ }
129+ }
130+
131+ void BlasConnector::gemv (const char trans, const int m, const int n,
132+ const std::complex <float > alpha, const std::complex <float > *A, const int lda, const std::complex <float > *X, const int incx,
133+ const std::complex <float > beta, std::complex <float > *Y, const int incy, base_device::AbacusDevice_t device_type)
134+ {
135+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
136+ cgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
137+ }
138+ }
139+
140+ void BlasConnector::gemv (const char trans, const int m, const int n,
141+ const std::complex <double > alpha, const std::complex <double > *A, const int lda, const std::complex <double > *X, const int incx,
142+ const std::complex <double > beta, std::complex <double > *Y, const int incy, base_device::AbacusDevice_t device_type)
143+ {
144+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
145+ zgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
146+ }
147+ }
148+
149+
150+ // out = ||x||_2
151+ float BlasConnector::nrm2 ( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
152+ {
153+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
154+ return snrm2_ ( &n, X, &incX );
155+ }
156+ }
157+
158+
159+ double BlasConnector::nrm2 ( const int n, const double *X, const int incX, base_device::AbacusDevice_t device_type )
160+ {
161+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162+ return dnrm2_ ( &n, X, &incX );
163+ }
164+ }
165+
166+
167+ double BlasConnector::nrm2 ( const int n, const std::complex <double > *X, const int incX, base_device::AbacusDevice_t device_type )
168+ {
169+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
170+ return dznrm2_ ( &n, X, &incX );
171+ }
172+ }
173+
174+ // copies a into b
175+ void BlasConnector::copy (const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type)
176+ {
177+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
178+ dcopy_ (&n, a, &incx, b, &incy);
179+ }
180+ }
181+
182+ void BlasConnector::copy (const long n, const std::complex <double > *a, const int incx, std::complex <double > *b, const int incy, base_device::AbacusDevice_t device_type)
183+ {
184+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
185+ zcopy_ (&n, a, &incx, b, &incy);
186+ }
187+ }
0 commit comments