@@ -30,7 +30,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
3030 else if (device_type == base_device::AbacusDevice_t::DspDevice){
3131 mtfunc::sgemm_mth_ (&transb, &transa, &n, &m, &k,
3232 &alpha, b, &ldb, a, &lda,
33- &beta, c, &ldc, GlobalV::MY_RANK);
33+ &beta, c, &ldc, GlobalV::MY_RANK % 4 );
3434 }
3535#endif
3636#ifdef __CUDA
@@ -67,7 +67,7 @@ void BlasConnector::gemm(const char transa,
6767#ifdef __DSP
6868 else if (device_type == base_device::AbacusDevice_t::DspDevice)
6969 {
70- mtfunc::dgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK);
70+ mtfunc::dgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
7171 }
7272#endif
7373 else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -106,7 +106,7 @@ void BlasConnector::gemm(const char transa,
106106#ifdef __DSP
107107 else if (device_type == base_device::AbacusDevice_t::DspDevice)
108108 {
109- mtfunc::cgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK);
109+ mtfunc::cgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
110110 }
111111#endif
112112 else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -157,7 +157,7 @@ void BlasConnector::gemm(const char transa,
157157#ifdef __DSP
158158 else if (device_type == base_device::AbacusDevice_t::DspDevice)
159159 {
160- mtfunc::zgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK);
160+ mtfunc::zgemm_mth_ (&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
161161 }
162162#endif
163163 else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -200,7 +200,7 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
200200 else if (device_type == base_device::AbacusDevice_t::DspDevice){
201201 mtfunc::sgemm_mth_ (&transb, &transa, &m, &n, &k,
202202 &alpha, a, &lda, b, &ldb,
203- &beta, c, &ldc, GlobalV::MY_RANK);
203+ &beta, c, &ldc, GlobalV::MY_RANK % 4 );
204204 }
205205#endif
206206#ifdef __CUDA
@@ -237,7 +237,7 @@ void BlasConnector::gemm_cm(const char transa,
237237#ifdef __DSP
238238 else if (device_type == base_device::AbacusDevice_t::DspDevice)
239239 {
240- mtfunc::dgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK);
240+ mtfunc::dgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
241241 }
242242#endif
243243#ifdef __CUDA
@@ -276,7 +276,7 @@ void BlasConnector::gemm_cm(const char transa,
276276#ifdef __DSP
277277 else if (device_type == base_device::AbacusDevice_t::DspDevice)
278278 {
279- mtfunc::cgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK);
279+ mtfunc::cgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
280280 }
281281#endif
282282#ifdef __CUDA
@@ -327,7 +327,7 @@ void BlasConnector::gemm_cm(const char transa,
327327#ifdef __DSP
328328 else if (device_type == base_device::AbacusDevice_t::DspDevice)
329329 {
330- mtfunc::zgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK);
330+ mtfunc::zgemm_mth_ (&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % 4 );
331331 }
332332#endif
333333#ifdef __CUDA
0 commit comments