Skip to content

Commit 934e9e9

Browse files
committed
Link mtblas library
1 parent 1a11dac commit 934e9e9

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF)
3939
option(ENABLE_CNPY "Enable cnpy usage." OFF)
4040
option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
4141
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
42+
option(ENABLE_DSP "Enable DSP usage." OFF)
4243

4344
# enable json support
4445
if(ENABLE_RAPIDJSON)
@@ -119,6 +120,12 @@ elseif(ENABLE_LCAO AND NOT ENABLE_MPI)
119120
set(ABACUS_BIN_NAME abacus_serial)
120121
endif()
121122

123+
if (USE_DSP)
124+
set(USE_ELPA OFF)
125+
set(ENABLE_LCAO OFF)
126+
set(ABACUS_BIN_NAME abacus_dsp)
127+
endif()
128+
122129
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
123130

124131
if(ENABLE_COVERAGE)
@@ -240,6 +247,10 @@ if(ENABLE_MPI)
240247
list(APPEND math_libs MPI::MPI_CXX)
241248
endif()
242249

250+
if (USE_DSP)
251+
target_link_libraries(${ABACUS_BIN_NAME} DIR_MTBLAS_LIBRARY)
252+
endif()
253+
243254
find_package(Threads REQUIRED)
244255
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)
245256

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#ifdef __DSP
2+
3+
// Base dsp functions
4+
void createMtblasHandle(int id);
5+
void destroyMtblasHandle();
6+
void *malloc_ht(size_t bytes);
7+
void free_ht(void* ptr);
8+
9+
10+
// mtblas functions
11+
12+
void sgemm_mt_(const char *transa, const char *transb,
13+
const int *m, const int *n, const int *k,
14+
const float *alpha, const float *a, const int *lda,
15+
const float *b, const int *ldb, const const float *beta,
16+
const float *c, const int *ldc);
17+
18+
void dgemm_mt_(const char *transa, const char *transb,
19+
const int *m, const int *n, const int *k,
20+
const double *alpha,const double *a, const int *lda,
21+
const double *b, const int *ldb, const double *beta,
22+
const double *c, const int *ldc);
23+
24+
void zgemm_mt_(const char *transa, const char *transb,
25+
const int *m, const int *n, const int *k,
26+
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
27+
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
28+
std::complex<double> *c, const int *ldc);
29+
30+
void cgemm_mt_(const char *transa, const char *transb,
31+
const int *m, const int *n, const int *k,
32+
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
33+
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
34+
std::complex<float> *c, const int *ldc);
35+
36+
37+
void sgemm_mth_(const char *transa, const char *transb,
38+
const int *m, const int *n, const int *k,
39+
const float *alpha, const float *a, const int *lda,
40+
const float *b, const int *ldb, const const float *beta,
41+
const float *c, const int *ldc);
42+
43+
void dgemm_mth_(const char *transa, const char *transb,
44+
const int *m, const int *n, const int *k,
45+
const double *alpha,const double *a, const int *lda,
46+
const double *b, const int *ldb, const double *beta,
47+
const double *c, const int *ldc);
48+
49+
void zgemm_mth_(const char *transa, const char *transb,
50+
const int *m, const int *n, const int *k,
51+
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
52+
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
53+
std::complex<double> *c, const int *ldc);
54+
55+
void cgemm_mth_(const char *transa, const char *transb,
56+
const int *m, const int *n, const int *k,
57+
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
58+
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
59+
std::complex<float> *c, const int *ldc);
60+
61+
//#define zgemm_ zgemm_mt
62+
63+
#endif

0 commit comments

Comments
 (0)