Skip to content

Commit 2d762d9

Browse files
committed
Add mtblas gemm kernel usage
1 parent 934e9e9 commit 2d762d9

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

source/module_base/blas_connector.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include "blas_connector.h"
22

3+
#ifdef __DSP
4+
#include "module_base/kernels/dsp/dsp_connector.h"
5+
#endif
6+
37
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
48
{
59
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
@@ -83,7 +87,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
8387
sgemm_(&transb, &transa, &n, &m, &k,
8488
&alpha, b, &ldb, a, &lda,
8589
&beta, c, &ldc);
86-
}
90+
}
91+
#ifdef __DSP
92+
else if (device_type == base_device::AbacusDevice_t::DspDevice){
93+
sgemm_mt_(&transb, &transa, &n, &m, &k,
94+
&alpha, b, &ldb, a, &lda,
95+
&beta, c, &ldc);
96+
}
97+
#endif
8798
}
8899

89100
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -94,7 +105,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
94105
dgemm_(&transb, &transa, &n, &m, &k,
95106
&alpha, b, &ldb, a, &lda,
96107
&beta, c, &ldc);
97-
}
108+
}
109+
#ifdef __DSP
110+
else if (device_type == base_device::AbacusDevice_t::DspDevice){
111+
sgemm_mt_(&transb, &transa, &n, &m, &k,
112+
&alpha, b, &ldb, a, &lda,
113+
&beta, c, &ldc);
114+
}
115+
#endif
98116
}
99117

100118
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -105,7 +123,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
105123
cgemm_(&transb, &transa, &n, &m, &k,
106124
&alpha, b, &ldb, a, &lda,
107125
&beta, c, &ldc);
108-
}
126+
}
127+
#ifdef __DSP
128+
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
129+
cgemm_mt_(&transb, &transa, &n, &m, &k,
130+
&alpha, b, &ldb, a, &lda,
131+
&beta, c, &ldc);
132+
}
133+
#endif
109134
}
110135

111136
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -116,7 +141,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
116141
zgemm_(&transb, &transa, &n, &m, &k,
117142
&alpha, b, &ldb, a, &lda,
118143
&beta, c, &ldc);
119-
}
144+
}
145+
#ifdef __DSP
146+
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
147+
zgemm_mt_(&transb, &transa, &n, &m, &k,
148+
&alpha, b, &ldb, a, &lda,
149+
&beta, c, &ldc);
150+
}
151+
#endif
120152
}
121153

122154
void BlasConnector::gemv(const char trans, const int m, const int n,

0 commit comments

Comments
 (0)