44#ifdef __DSP
55#include " source_base/kernels/dsp/dsp_connector.h"
66#include " source_base/global_variable.h"
7+ #include " source_io/module_parameter/parameter.h"
78#endif
89
910#ifdef __CUDA
@@ -30,7 +31,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
3031 else if (device_type == base_device::AbacusDevice_t::DspDevice){
3132 mtfunc::sgemm_mth_ (&transb, &transa, &n, &m, &k,
3233 &alpha, b, &ldb, a, &lda,
33- &beta, c, &ldc, GlobalV::MY_RANK % 4 );
34+ &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
3435 }
3536#endif
3637#ifdef __CUDA
@@ -67,7 +68,7 @@ void BlasConnector::gemm(const char transa,
6768#ifdef __DSP
6869 else if (device_type == base_device::AbacusDevice_t::DspDevice)
6970 {
70- mtfunc::dgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
71+ mtfunc::dgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
7172 }
7273#endif
7374 else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -106,7 +107,7 @@ void BlasConnector::gemm(const char transa,
106107#ifdef __DSP
107108 else if (device_type == base_device::AbacusDevice_t::DspDevice)
108109 {
109- mtfunc::cgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
110+ mtfunc::cgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
110111 }
111112#endif
112113 else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -157,7 +158,7 @@ void BlasConnector::gemm(const char transa,
157158#ifdef __DSP
158159 else if (device_type == base_device::AbacusDevice_t::DspDevice)
159160 {
160- mtfunc::zgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
161+ mtfunc::zgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
161162 }
162163#endif
163164 else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -200,7 +201,7 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
200201 else if (device_type == base_device::AbacusDevice_t::DspDevice){
201202 mtfunc::sgemm_mth_ (&transb, &transa, &m, &n, &k,
202203 &alpha, a, &lda, b, &ldb,
203- &beta, c, &ldc, GlobalV::MY_RANK % 4 );
204+ &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
204205 }
205206#endif
206207#ifdef __CUDA
@@ -237,7 +238,7 @@ void BlasConnector::gemm_cm(const char transa,
237238#ifdef __DSP
238239 else if (device_type == base_device::AbacusDevice_t::DspDevice)
239240 {
240- mtfunc::dgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
241+ mtfunc::dgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
241242 }
242243#endif
243244#ifdef __CUDA
@@ -276,7 +277,7 @@ void BlasConnector::gemm_cm(const char transa,
276277#ifdef __DSP
277278 else if (device_type == base_device::AbacusDevice_t::DspDevice)
278279 {
279- mtfunc::cgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
280+ mtfunc::cgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
280281 }
281282#endif
282283#ifdef __CUDA
@@ -327,7 +328,7 @@ void BlasConnector::gemm_cm(const char transa,
327328#ifdef __DSP
328329 else if (device_type == base_device::AbacusDevice_t::DspDevice)
329330 {
330- mtfunc::zgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
331+ mtfunc::zgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM. inp . dsp_count );
331332 }
332333#endif
333334#ifdef __CUDA
0 commit comments