11#include " blas_connector.h"
22
3+ #ifdef __DSP
4+ #include " module_base/kernels/dsp/dsp_connector.h"
5+ #endif
6+
37void BlasConnector::axpy ( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
48{
59 if (device_type == base_device::AbacusDevice_t::CpuDevice) {
@@ -83,7 +87,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
8387 sgemm_ (&transb, &transa, &n, &m, &k,
8488 &alpha, b, &ldb, a, &lda,
8589 &beta, c, &ldc);
86- }
90+ }
91+ #ifdef __DSP
92+ else if (device_type == base_device::AbacusDevice_t::DspDevice){
93+ sgemm_mt_ (&transb, &transa, &n, &m, &k,
94+ &alpha, b, &ldb, a, &lda,
95+ &beta, c, &ldc);
96+ }
97+ #endif
8798}
8899
89100void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -94,7 +105,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
94105 dgemm_ (&transb, &transa, &n, &m, &k,
95106 &alpha, b, &ldb, a, &lda,
96107 &beta, c, &ldc);
97- }
108+ }
109+ #ifdef __DSP
110+ else if (device_type == base_device::AbacusDevice_t::DspDevice){
111+ sgemm_mt_ (&transb, &transa, &n, &m, &k,
112+ &alpha, b, &ldb, a, &lda,
113+ &beta, c, &ldc);
114+ }
115+ #endif
98116}
99117
100118void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -105,7 +123,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
105123 cgemm_ (&transb, &transa, &n, &m, &k,
106124 &alpha, b, &ldb, a, &lda,
107125 &beta, c, &ldc);
108- }
126+ }
127+ #ifdef __DSP
128+ else if (device_type == base_device::AbacusDevice_t::DspDevice) {
129+ cgemm_mt_ (&transb, &transa, &n, &m, &k,
130+ &alpha, b, &ldb, a, &lda,
131+ &beta, c, &ldc);
132+ }
133+ #endif
109134}
110135
111136void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -116,7 +141,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
116141 zgemm_ (&transb, &transa, &n, &m, &k,
117142 &alpha, b, &ldb, a, &lda,
118143 &beta, c, &ldc);
119- }
144+ }
145+ #ifdef __DSP
146+ else if (device_type == base_device::AbacusDevice_t::DspDevice) {
147+ zgemm_mt_ (&transb, &transa, &n, &m, &k,
148+ &alpha, b, &ldb, a, &lda,
149+ &beta, c, &ldc);
150+ }
151+ #endif
120152}
121153
122154void BlasConnector::gemv (const char trans, const int m, const int n,
0 commit comments