Skip to content

Commit d82c508

Browse files
committed
Fix parallel function
1 parent d7150d0 commit d82c508

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

source/module_base/blas_connector.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#ifdef __DSP
44
#include "module_base/kernels/dsp/dsp_connector.h"
5+
#include "module_base/global_variable.h"
56
#endif
67

78
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
@@ -94,7 +95,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
9495
else if (device_type == base_device::AbacusDevice_t::DspDevice){
9596
sgemm_mt_(&transb, &transa, &n, &m, &k,
9697
&alpha, b, &ldb, a, &lda,
97-
&beta, c, &ldc);
98+
&beta, c, &ldc, GlobalV::MY_RANK);
9899
}
99100
#endif
100101
}
@@ -112,7 +113,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
112113
else if (device_type == base_device::AbacusDevice_t::DspDevice){
113114
dgemm_mt_(&transb, &transa, &n, &m, &k,
114115
&alpha, b, &ldb, a, &lda,
115-
&beta, c, &ldc);
116+
&beta, c, &ldc, GlobalV::MY_RANK);
116117
}
117118
#endif
118119
}
@@ -130,7 +131,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
130131
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
131132
cgemm_mt_(&transb, &transa, &n, &m, &k,
132133
&alpha, b, &ldb, a, &lda,
133-
&beta, c, &ldc);
134+
&beta, c, &ldc, GlobalV::MY_RANK);
134135
}
135136
#endif
136137
}
@@ -148,7 +149,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
148149
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
149150
zgemm_mt_(&transb, &transa, &n, &m, &k,
150151
&alpha, b, &ldb, a, &lda,
151-
&beta, c, &ldc);
152+
&beta, c, &ldc, GlobalV::MY_RANK);
152153
}
153154
#endif
154155
}

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,50 +15,50 @@ void sgemm_mt_(const char *transa, const char *transb,
1515
const int *m, const int *n, const int *k,
1616
const float *alpha, const float *a, const int *lda,
1717
const float *b, const int *ldb, const float *beta,
18-
float *c, const int *ldc);
18+
float *c, const int *ldc, int cluster_id);
1919

2020
void dgemm_mt_(const char *transa, const char *transb,
2121
const int *m, const int *n, const int *k,
2222
const double *alpha,const double *a, const int *lda,
2323
const double *b, const int *ldb, const double *beta,
24-
double *c, const int *ldc);
24+
double *c, const int *ldc, int cluster_id);
2525

2626
void zgemm_mt_(const char *transa, const char *transb,
2727
const int *m, const int *n, const int *k,
2828
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
2929
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
30-
std::complex<double> *c, const int *ldc);
30+
std::complex<double> *c, const int *ldc, int cluster_id);
3131

3232
void cgemm_mt_(const char *transa, const char *transb,
3333
const int *m, const int *n, const int *k,
3434
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
3535
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
36-
std::complex<float> *c, const int *ldc);
36+
std::complex<float> *c, const int *ldc, int cluster_id);
3737

3838

3939
void sgemm_mth_(const char *transa, const char *transb,
4040
const int *m, const int *n, const int *k,
4141
const float *alpha, const float *a, const int *lda,
4242
const float *b, const int *ldb, const float *beta,
43-
float *c, const int *ldc);
43+
float *c, const int *ldc, int cluster_id);
4444

4545
void dgemm_mth_(const char *transa, const char *transb,
4646
const int *m, const int *n, const int *k,
4747
const double *alpha,const double *a, const int *lda,
4848
const double *b, const int *ldb, const double *beta,
49-
double *c, const int *ldc);
49+
double *c, const int *ldc, int cluster_id);
5050

5151
void zgemm_mth_(const char *transa, const char *transb,
5252
const int *m, const int *n, const int *k,
5353
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
5454
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
55-
std::complex<double> *c, const int *ldc);
55+
std::complex<double> *c, const int *ldc, int cluster_id);
5656

5757
void cgemm_mth_(const char *transa, const char *transb,
5858
const int *m, const int *n, const int *k,
5959
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
6060
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
61-
std::complex<float> *c, const int *ldc);
61+
std::complex<float> *c, const int *ldc, int cluster_id);
6262

6363
//#define zgemm_ zgemm_mt
6464

0 commit comments

Comments
 (0)