Skip to content

Commit f039250

Browse files
Feature: Porting abacus to DSP hardware (mtblas part) (#5301)
* Link mtblas library
* Add mtblas gemm kernel usage
* Finish memory_op on dsp
* Update CMakeLists
* Add compilation script
* Fix warnings
* Fix install script
* Initialize DSP hardware
* Replace gemm in math_kernel
* Fix CMakeLists Bug
* Fix bugs #1
* Fix bug 2
* Fix link to shared library error
* Stop use gemm_mt globally
* Modify op usage
* Fix bug
* Fix template usage
* Fix compilation
* Replace all dav_subspace gemm kernels

---------

Co-authored-by: Mohan Chen <[email protected]>
1 parent 2af2095 commit f039250

File tree

10 files changed

+251
-11
lines changed

10 files changed

+251
-11
lines changed

CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF)
3939
option(ENABLE_CNPY "Enable cnpy usage." OFF)
4040
option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
4141
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
42+
option(USE_DSP "Enable DSP usage." OFF)
4243

4344
# enable json support
4445
if(ENABLE_RAPIDJSON)
@@ -119,6 +120,12 @@ elseif(ENABLE_LCAO AND NOT ENABLE_MPI)
119120
set(ABACUS_BIN_NAME abacus_serial)
120121
endif()
121122

123+
if (USE_DSP)
124+
set(USE_ELPA OFF)
125+
set(ENABLE_LCAO OFF)
126+
set(ABACUS_BIN_NAME abacus_dsp)
127+
endif()
128+
122129
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
123130

124131
if(ENABLE_COVERAGE)
@@ -240,6 +247,11 @@ if(ENABLE_MPI)
240247
list(APPEND math_libs MPI::MPI_CXX)
241248
endif()
242249

250+
if (USE_DSP)
251+
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY})
252+
add_compile_definitions(__DSP)
253+
endif()
254+
243255
find_package(Threads REQUIRED)
244256
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)
245257

install_dsp.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
CXX=mpicxx \
2+
cmake -B build \
3+
-DUSE_DSP=ON \
4+
-DENABLE_LCAO=OFF \
5+
-DFFTW3_DIR=/vol8/appsoftware/fftw/ \
6+
-DFFTW3_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3.so \
7+
-DFFTW3_OMP_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3_omp.so \
8+
-DFFTW3_FLOAT_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3f.so \
9+
-DLAPACK_DIR=/vol8/appsoftware/openblas/0.3.21/lib \
10+
-DDIR_MTBLAS_LIBRARY=/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus.so

source/module_base/blas_connector.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include "blas_connector.h"
22

3+
#ifdef __DSP
4+
#include "module_base/kernels/dsp/dsp_connector.h"
5+
#endif
6+
37
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
48
{
59
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
@@ -64,13 +68,15 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
6468
{
6569
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
6670
return sdot_(&n, X, &incX, Y, &incY);
71+
return sdot_(&n, X, &incX, Y, &incY);
6772
}
6873
}
6974

7075
double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
7176
{
7277
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
7378
return ddot_(&n, X, &incX, Y, &incY);
79+
return ddot_(&n, X, &incX, Y, &incY);
7480
}
7581
}
7682

@@ -83,7 +89,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
8389
sgemm_(&transb, &transa, &n, &m, &k,
8490
&alpha, b, &ldb, a, &lda,
8591
&beta, c, &ldc);
86-
}
92+
}
93+
#ifdef __DSP
94+
else if (device_type == base_device::AbacusDevice_t::DspDevice){
95+
sgemm_mt_(&transb, &transa, &n, &m, &k,
96+
&alpha, b, &ldb, a, &lda,
97+
&beta, c, &ldc);
98+
}
99+
#endif
87100
}
88101

89102
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -94,7 +107,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
94107
dgemm_(&transb, &transa, &n, &m, &k,
95108
&alpha, b, &ldb, a, &lda,
96109
&beta, c, &ldc);
97-
}
110+
}
111+
#ifdef __DSP
112+
else if (device_type == base_device::AbacusDevice_t::DspDevice){
113+
dgemm_mt_(&transb, &transa, &n, &m, &k,
114+
&alpha, b, &ldb, a, &lda,
115+
&beta, c, &ldc);
116+
}
117+
#endif
98118
}
99119

100120
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -105,7 +125,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
105125
cgemm_(&transb, &transa, &n, &m, &k,
106126
&alpha, b, &ldb, a, &lda,
107127
&beta, c, &ldc);
108-
}
128+
}
129+
#ifdef __DSP
130+
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
131+
cgemm_mt_(&transb, &transa, &n, &m, &k,
132+
&alpha, b, &ldb, a, &lda,
133+
&beta, c, &ldc);
134+
}
135+
#endif
109136
}
110137

111138
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -116,7 +143,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
116143
zgemm_(&transb, &transa, &n, &m, &k,
117144
&alpha, b, &ldb, a, &lda,
118145
&beta, c, &ldc);
119-
}
146+
}
147+
#ifdef __DSP
148+
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
149+
zgemm_mt_(&transb, &transa, &n, &m, &k,
150+
&alpha, b, &ldb, a, &lda,
151+
&beta, c, &ldc);
152+
}
153+
#endif
120154
}
121155

122156
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -152,6 +186,7 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
152186
{
153187
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
154188
return snrm2_( &n, X, &incX );
189+
return snrm2_( &n, X, &incX );
155190
}
156191
}
157192

@@ -160,6 +195,7 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
160195
{
161196
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
162197
return dnrm2_( &n, X, &incX );
198+
return dnrm2_( &n, X, &incX );
163199
}
164200
}
165201

@@ -168,6 +204,7 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
168204
{
169205
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
170206
return dznrm2_( &n, X, &incX );
207+
return dznrm2_( &n, X, &incX );
171208
}
172209
}
173210

source/module_base/kernels/dsp/dsp_connector.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef DSP_CONNECTOR_H
2+
#define DSP_CONNECTOR_H
3+
#ifdef __DSP
4+
5+
// Base dsp functions
6+
void dspInitHandle(int id);
7+
void dspDestoryHandle();
8+
void *malloc_ht(size_t bytes);
9+
void free_ht(void* ptr);
10+
11+
12+
// mtblas functions
13+
14+
void sgemm_mt_(const char *transa, const char *transb,
15+
const int *m, const int *n, const int *k,
16+
const float *alpha, const float *a, const int *lda,
17+
const float *b, const int *ldb, const float *beta,
18+
float *c, const int *ldc);
19+
20+
void dgemm_mt_(const char *transa, const char *transb,
21+
const int *m, const int *n, const int *k,
22+
const double *alpha,const double *a, const int *lda,
23+
const double *b, const int *ldb, const double *beta,
24+
double *c, const int *ldc);
25+
26+
void zgemm_mt_(const char *transa, const char *transb,
27+
const int *m, const int *n, const int *k,
28+
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
29+
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
30+
std::complex<double> *c, const int *ldc);
31+
32+
void cgemm_mt_(const char *transa, const char *transb,
33+
const int *m, const int *n, const int *k,
34+
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
35+
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
36+
std::complex<float> *c, const int *ldc);
37+
38+
39+
void sgemm_mth_(const char *transa, const char *transb,
40+
const int *m, const int *n, const int *k,
41+
const float *alpha, const float *a, const int *lda,
42+
const float *b, const int *ldb, const float *beta,
43+
float *c, const int *ldc);
44+
45+
void dgemm_mth_(const char *transa, const char *transb,
46+
const int *m, const int *n, const int *k,
47+
const double *alpha,const double *a, const int *lda,
48+
const double *b, const int *ldb, const double *beta,
49+
double *c, const int *ldc);
50+
51+
void zgemm_mth_(const char *transa, const char *transb,
52+
const int *m, const int *n, const int *k,
53+
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
54+
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
55+
std::complex<double> *c, const int *ldc);
56+
57+
void cgemm_mth_(const char *transa, const char *transb,
58+
const int *m, const int *n, const int *k,
59+
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
60+
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
61+
std::complex<float> *c, const int *ldc);
62+
63+
//#define zgemm_ zgemm_mt
64+
65+
#endif
66+
#endif

source/module_base/module_device/memory_op.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
#include "module_base/memory.h"
44
#include "module_base/tool_threading.h"
5+
#ifdef __DSP
6+
#include "module_base/kernels/dsp/dsp_connector.h"
7+
#endif
58

69
#include <complex>
710
#include <cstring>
@@ -18,9 +21,17 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
1821
{
1922
if (arr != nullptr)
2023
{
24+
#ifdef __DSP
25+
free_ht(arr);
26+
#else
2127
free(arr);
28+
#endif
2229
}
30+
#ifdef __DSP
31+
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size);
32+
#else
2333
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
34+
#endif
2435
std::string record_string;
2536
if (record_in != nullptr)
2637
{
@@ -92,7 +103,11 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
92103
{
93104
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
94105
{
106+
#ifdef __DSP
107+
free_ht(arr);
108+
#else
95109
free(arr);
110+
#endif
96111
}
97112
};
98113

source/module_base/module_device/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ enum AbacusDevice_t
1212
UnKnown,
1313
CpuDevice,
1414
GpuDevice,
15+
DspDevice
1516
};
1617

1718
} // namespace base_device

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@
4949
#include <ATen/kernels/blas.h>
5050
#include <ATen/kernels/lapack.h>
5151

52+
#ifdef __DSP
53+
#include "module_base/kernels/dsp/dsp_connector.h"
54+
#endif
55+
5256
namespace ModuleESolver
5357
{
5458

@@ -67,6 +71,10 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
6771
container::kernels::createGpuSolverHandle();
6872
}
6973
#endif
74+
#ifdef __DSP
75+
std::cout << " ** Initializing DSP Hardware..." << std::endl;
76+
dspInitHandle(GlobalV::MY_RANK % 4);
77+
#endif
7078
}
7179

7280
template <typename T, typename Device>
@@ -92,7 +100,10 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
92100
#endif
93101
delete reinterpret_cast<psi::Psi<T, Device>*>(this->kspw_psi);
94102
}
95-
103+
#ifdef __DSP
104+
std::cout << " ** Closing DSP Hardware..." << std::endl;
105+
dspDestoryHandle();
106+
#endif
96107
if (PARAM.inp.precision == "single")
97108
{
98109
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,12 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
181181
// update eigenvectors of Hamiltonian
182182
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);
183183

184-
gemm_op<T, Device>()(this->ctx,
184+
#ifdef __DSP
185+
gemm_op_mt<T, Device>()
186+
#else
187+
gemm_op<T, Device>()
188+
#endif
189+
(this->ctx,
185190
'N',
186191
'N',
187192
this->dim,
@@ -262,7 +267,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
262267
}
263268
}
264269

265-
gemm_op<T, Device>()(this->ctx,
270+
#ifdef __DSP
271+
gemm_op_mt<T, Device>()
272+
#else
273+
gemm_op<T, Device>()
274+
#endif
275+
(this->ctx,
266276
'N',
267277
'N',
268278
this->dim,
@@ -302,7 +312,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
302312
delmem_real_op()(this->ctx, e_temp_hd);
303313
}
304314

305-
gemm_op<T, Device>()(this->ctx,
315+
#ifdef __DSP
316+
gemm_op_mt<T, Device>()
317+
#else
318+
gemm_op<T, Device>()
319+
#endif
320+
(this->ctx,
306321
'N',
307322
'N',
308323
this->dim,
@@ -386,7 +401,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
386401
{
387402
ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem");
388403

389-
gemm_op<T, Device>()(this->ctx,
404+
#ifdef __DSP
405+
gemm_op_mt<T, Device>()
406+
#else
407+
gemm_op<T, Device>()
408+
#endif
409+
(this->ctx,
390410
'C',
391411
'N',
392412
nbase + notconv,
@@ -401,7 +421,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
401421
&hcc[nbase * this->nbase_x],
402422
this->nbase_x);
403423

404-
gemm_op<T, Device>()(this->ctx,
424+
#ifdef __DSP
425+
gemm_op_mt<T, Device>()
426+
#else
427+
gemm_op<T, Device>()
428+
#endif
429+
(this->ctx,
405430
'C',
406431
'N',
407432
nbase + notconv,
@@ -603,7 +628,12 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
603628
{
604629
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");
605630

606-
gemm_op<T, Device>()(this->ctx,
631+
#ifdef __DSP
632+
gemm_op_mt<T, Device>()
633+
#else
634+
gemm_op<T, Device>()
635+
#endif
636+
(this->ctx,
607637
'N',
608638
'N',
609639
this->dim,

0 commit comments

Comments
 (0)