11#include " blas_connector.h"
22
3+ #ifdef __DSP
4+ #include " module_base/kernels/dsp/dsp_connector.h"
5+ #endif
6+
37void BlasConnector::axpy ( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
48{
59 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
@@ -64,13 +68,15 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
6468{
6569 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
6670 return sdot_ (&n, X, &incX, Y, &incY);
71+ return sdot_ (&n, X, &incX, Y, &incY);
6772}
6873}
6974
7075double BlasConnector::dot ( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
7176{
7277 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
7378 return ddot_ (&n, X, &incX, Y, &incY);
79+ return ddot_ (&n, X, &incX, Y, &incY);
7480}
7581}
7682
@@ -83,7 +89,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
8389 sgemm_ (&transb, &transa, &n, &m, &k,
8490 &alpha, b, &ldb, a, &lda,
8591 &beta, c, &ldc);
86- }
92+ }
93+ #ifdef __DSP
94+ else if (device_type == base_device::AbacusDevice_t::DspDevice){
95+ sgemm_mt_ (&transb, &transa, &n, &m, &k,
96+ &alpha, b, &ldb, a, &lda,
97+ &beta, c, &ldc);
98+ }
99+ #endif
87100}
88101
89102void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -94,7 +107,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
94107 dgemm_ (&transb, &transa, &n, &m, &k,
95108 &alpha, b, &ldb, a, &lda,
96109 &beta, c, &ldc);
97- }
110+ }
111+ #ifdef __DSP
112+ else if (device_type == base_device::AbacusDevice_t::DspDevice){
113+ dgemm_mt_ (&transb, &transa, &n, &m, &k,
114+ &alpha, b, &ldb, a, &lda,
115+ &beta, c, &ldc);
116+ }
117+ #endif
98118}
99119
100120void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -105,7 +125,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
105125 cgemm_ (&transb, &transa, &n, &m, &k,
106126 &alpha, b, &ldb, a, &lda,
107127 &beta, c, &ldc);
108- }
128+ }
129+ #ifdef __DSP
130+ else if (device_type == base_device::AbacusDevice_t::DspDevice) {
131+ cgemm_mt_ (&transb, &transa, &n, &m, &k,
132+ &alpha, b, &ldb, a, &lda,
133+ &beta, c, &ldc);
134+ }
135+ #endif
109136}
110137
111138void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -116,7 +143,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
116143 zgemm_ (&transb, &transa, &n, &m, &k,
117144 &alpha, b, &ldb, a, &lda,
118145 &beta, c, &ldc);
119- }
146+ }
147+ #ifdef __DSP
148+ else if (device_type == base_device::AbacusDevice_t::DspDevice) {
149+ zgemm_mt_ (&transb, &transa, &n, &m, &k,
150+ &alpha, b, &ldb, a, &lda,
151+ &beta, c, &ldc);
152+ }
153+ #endif
120154}
121155
122156void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -152,6 +186,7 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
152186{
153187 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
154188 return snrm2_ ( &n, X, &incX );
189+ return snrm2_ ( &n, X, &incX );
155190}
156191}
157192
@@ -160,6 +195,7 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
160195{
161196 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162197 return dnrm2_ ( &n, X, &incX );
198+ return dnrm2_ ( &n, X, &incX );
163199}
164200}
165201
@@ -168,6 +204,7 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
168204{
169205 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
170206 return dznrm2_ ( &n, X, &incX );
207+ return dznrm2_ ( &n, X, &incX );
171208}
172209}
173210
0 commit comments