Skip to content

Commit 72d9d1d

Browse files
authored
Feature: MPI available on ABACUS DSP version (#5351)
* Fix parallel function * Fix parallel usage * Temporarily remove memory_op porting
1 parent 729059f commit 72d9d1d

File tree

4 files changed

+18
-28
lines changed

4 files changed

+18
-28
lines changed

source/module_base/blas_connector.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#ifdef __DSP
44
#include "module_base/kernels/dsp/dsp_connector.h"
5+
#include "module_base/global_variable.h"
56
#endif
67

78
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
@@ -94,7 +95,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
9495
else if (device_type == base_device::AbacusDevice_t::DspDevice){
9596
sgemm_mt_(&transb, &transa, &n, &m, &k,
9697
&alpha, b, &ldb, a, &lda,
97-
&beta, c, &ldc);
98+
&beta, c, &ldc, GlobalV::MY_RANK);
9899
}
99100
#endif
100101
}
@@ -112,7 +113,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
112113
else if (device_type == base_device::AbacusDevice_t::DspDevice){
113114
dgemm_mt_(&transb, &transa, &n, &m, &k,
114115
&alpha, b, &ldb, a, &lda,
115-
&beta, c, &ldc);
116+
&beta, c, &ldc, GlobalV::MY_RANK);
116117
}
117118
#endif
118119
}
@@ -130,7 +131,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
130131
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
131132
cgemm_mt_(&transb, &transa, &n, &m, &k,
132133
&alpha, b, &ldb, a, &lda,
133-
&beta, c, &ldc);
134+
&beta, c, &ldc, GlobalV::MY_RANK);
134135
}
135136
#endif
136137
}
@@ -148,7 +149,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
148149
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
149150
zgemm_mt_(&transb, &transa, &n, &m, &k,
150151
&alpha, b, &ldb, a, &lda,
151-
&beta, c, &ldc);
152+
&beta, c, &ldc, GlobalV::MY_RANK);
152153
}
153154
#endif
154155
}

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
// Base dsp functions
66
void dspInitHandle(int id);
7-
void dspDestoryHandle();
8-
void *malloc_ht(size_t bytes);
7+
void dspDestoryHandle(int id);
8+
void *malloc_ht(size_t bytes, int cluster_id);
99
void free_ht(void* ptr);
1010

1111

@@ -15,50 +15,50 @@ void sgemm_mt_(const char *transa, const char *transb,
1515
const int *m, const int *n, const int *k,
1616
const float *alpha, const float *a, const int *lda,
1717
const float *b, const int *ldb, const float *beta,
18-
float *c, const int *ldc);
18+
float *c, const int *ldc, int cluster_id);
1919

2020
void dgemm_mt_(const char *transa, const char *transb,
2121
const int *m, const int *n, const int *k,
2222
const double *alpha,const double *a, const int *lda,
2323
const double *b, const int *ldb, const double *beta,
24-
double *c, const int *ldc);
24+
double *c, const int *ldc, int cluster_id);
2525

2626
void zgemm_mt_(const char *transa, const char *transb,
2727
const int *m, const int *n, const int *k,
2828
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
2929
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
30-
std::complex<double> *c, const int *ldc);
30+
std::complex<double> *c, const int *ldc, int cluster_id);
3131

3232
void cgemm_mt_(const char *transa, const char *transb,
3333
const int *m, const int *n, const int *k,
3434
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
3535
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
36-
std::complex<float> *c, const int *ldc);
36+
std::complex<float> *c, const int *ldc, int cluster_id);
3737

3838

3939
void sgemm_mth_(const char *transa, const char *transb,
4040
const int *m, const int *n, const int *k,
4141
const float *alpha, const float *a, const int *lda,
4242
const float *b, const int *ldb, const float *beta,
43-
float *c, const int *ldc);
43+
float *c, const int *ldc, int cluster_id);
4444

4545
void dgemm_mth_(const char *transa, const char *transb,
4646
const int *m, const int *n, const int *k,
4747
const double *alpha,const double *a, const int *lda,
4848
const double *b, const int *ldb, const double *beta,
49-
double *c, const int *ldc);
49+
double *c, const int *ldc, int cluster_id);
5050

5151
void zgemm_mth_(const char *transa, const char *transb,
5252
const int *m, const int *n, const int *k,
5353
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
5454
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
55-
std::complex<double> *c, const int *ldc);
55+
std::complex<double> *c, const int *ldc, int cluster_id);
5656

5757
void cgemm_mth_(const char *transa, const char *transb,
5858
const int *m, const int *n, const int *k,
5959
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
6060
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
61-
std::complex<float> *c, const int *ldc);
61+
std::complex<float> *c, const int *ldc, int cluster_id);
6262

6363
//#define zgemm_ zgemm_mt
6464

source/module_base/module_device/memory_op.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "module_base/tool_threading.h"
55
#ifdef __DSP
66
#include "module_base/kernels/dsp/dsp_connector.h"
7+
#include "module_base/global_variable.h"
78
#endif
89

910
#include <complex>
@@ -21,17 +22,9 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
2122
{
2223
if (arr != nullptr)
2324
{
24-
#ifdef __DSP
25-
free_ht(arr);
26-
#else
2725
free(arr);
28-
#endif
2926
}
30-
#ifdef __DSP
31-
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size);
32-
#else
3327
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
34-
#endif
3528
std::string record_string;
3629
if (record_in != nullptr)
3730
{
@@ -103,11 +96,7 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
10396
{
10497
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
10598
{
106-
#ifdef __DSP
107-
free_ht(arr);
108-
#else
10999
free(arr);
110-
#endif
111100
}
112101
};
113102

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
7373
#endif
7474
#ifdef __DSP
7575
std::cout << " ** Initializing DSP Hardware..." << std::endl;
76-
dspInitHandle(GlobalV::MY_RANK % 4);
76+
dspInitHandle(GlobalV::MY_RANK);
7777
#endif
7878
}
7979

@@ -102,7 +102,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
102102
}
103103
#ifdef __DSP
104104
std::cout << " ** Closing DSP Hardware..." << std::endl;
105-
dspDestoryHandle();
105+
dspDestoryHandle(GlobalV::MY_RANK);
106106
#endif
107107
if (PARAM.inp.precision == "single")
108108
{

0 commit comments

Comments
 (0)