Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,12 @@ if(ENABLE_MPI)
endif()

if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY})
add_compile_definitions(__DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1})
include_directories(${MTBLAS_FFT_DIR}/libmtblas/include)
include_directories(${MT_HOST_DIR}/include)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a)
endif()

find_package(Threads REQUIRED)
Expand Down Expand Up @@ -429,10 +433,8 @@ else()
find_package(Lapack REQUIRED)
include_directories(${FFTW3_INCLUDE_DIRS})
list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS)

find_package(ScaLAPACK REQUIRED)
list(APPEND math_libs ScaLAPACK::ScaLAPACK)

if(USE_OPENMP)
list(APPEND math_libs FFTW3::FFTW3_OMP)
endif()
Expand Down
7 changes: 7 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ if(USE_ROCM)
)
endif()

if(USE_DSP)
list(APPEND device_srcs
module_base/kernels/dsp/dsp_connector.cpp
)
endif()


add_library(device OBJECT ${device_srcs})

if(USE_CUDA)
Expand Down
5 changes: 4 additions & 1 deletion source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ add_library(
)

target_link_libraries(base PUBLIC container)

if (USE_DSP)
target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblas.a)
target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblasdev.a)
endif()
add_subdirectory(module_container)

if(ENABLE_COVERAGE)
Expand Down
201 changes: 201 additions & 0 deletions source/module_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
#include "dsp_connector.h"
#include <iostream>
#include <complex>

extern "C"
{
#define complex_double ignore_complex_double
#include <mt_hthread_blas.h> // MTBLAS_TRANSPOSE etc
#undef complex_double
#include <mtblas_interface.h> // gemm
}

void dspInitHandle(int id){
mt_blas_init(id);
std::cout << " ** DSP inited on cluster "<< id << " **" << std::endl;
} // Use this at the beginning of the program to start a dsp cluster

void dspDestoryHandle(int id){
hthread_dev_close(id);
std::cout << " ** DSP closed on cluster "<< id << " **" << std::endl;
} // Close dsp cluster at the end


MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) {
switch (blasTrans[0]) {
case 'N':
case 'n':
return MtblasNoTrans;
case 'T':
case 't':
return MtblasTrans;
case 'C':
case 'c':
return MtblasConjTrans;
default:
std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl;
return MtblasNoTrans;
}
} // Used to convert normal transpost char to mtblas transpose flag


void* malloc_ht(size_t bytes, int cluster_id)
{
//std::cout << "MALLOC " << cluster_id;
void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW);
//std::cout << ptr << " SUCCEED" << std::endl;;
return ptr;
}

// Used to replace original malloc

void free_ht(void* ptr)
{
//std::cout << "FREE " << ptr;
hthread_free(ptr);
//std::cout << " FREE SUCCEED" << std::endl;
}

// Used to replace original free

void sgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const float *alpha, const float *a, const int *lda,
const float *b, const int *ldb, const float *beta,
float *c, const int *ldc, int cluster_id)
{
mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
*alpha, a, *lda,
b, *ldb, *beta,
c, *ldc, cluster_id
);
} // zgemm that needn't malloc_ht or free_ht

void dgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const double *alpha, const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc, int cluster_id)
{
mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
*alpha, a, *lda,
b, *ldb, *beta,
c, *ldc, cluster_id
);
} // cgemm that needn't malloc_ht or free_ht

void zgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
std::complex<double> *c, const int *ldc, int cluster_id)
{
mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
(const void*)alpha, (const void*)a, *lda,
(const void*)b, *ldb, (const void*)beta,
(void*)c, *ldc, cluster_id
);
} // zgemm that needn't malloc_ht or free_ht

void cgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
std::complex<float> *c, const int *ldc, int cluster_id)
{
mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
(const void*)alpha, (const void*)a, *lda,
(const void*)b, *ldb, (const void*)beta,
(void*)c, *ldc, cluster_id
);
} // cgemm that needn't malloc_ht or free_ht

// Used to replace original free

void sgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const float *alpha, const float *a, const int *lda,
const float *b, const int *ldb, const float *beta,
float *c, const int *ldc, int cluster_id)
{
mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
*alpha, a, *lda,
b, *ldb, *beta,
c, *ldc, cluster_id
);
} // zgemm that needn't malloc_ht or free_ht

void dgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const double *alpha, const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc, int cluster_id)
{
mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
*alpha, a, *lda,
b, *ldb, *beta,
c, *ldc, cluster_id
);
} // cgemm that needn't malloc_ht or free_ht

void zgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<double> *alpha,
const std::complex<double> *a,
const int *lda,
const std::complex<double> *b,
const int *ldb,
const std::complex<double> *beta,
std::complex<double> *c,
const int *ldc,
int cluster_id)
{
std::complex<double>* alp = (std::complex<double>*) malloc_ht(sizeof(std::complex<double>), cluster_id);
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*) malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
alp, a, *lda,
b, *ldb, bet,
c, *ldc, cluster_id
);


} // zgemm that needn't malloc_ht or free_ht

void cgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
std::complex<float> *c, const int *ldc, int cluster_id)
{
std::complex<float>* alp = (std::complex<float>*) malloc_ht(sizeof(std::complex<float>), cluster_id);
*alp = *alpha;
std::complex<float>* bet = (std::complex<float>*) malloc_ht(sizeof(std::complex<float>), cluster_id);
*bet = *beta;

mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),convertBLASTranspose(transb),
*m,*n,*k,
(const void*)alp, (const void*)a, *lda,
(const void*)b, *ldb, (const void*)bet,
(void*)c, *ldc, cluster_id
);

free_ht(alp);
free_ht(bet);
} // cgemm that needn't malloc_ht or free_ht
12 changes: 12 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ if (USE_ROCM)
module_fft/fft_rocm.cpp
)
endif()
if (USE_DSP)
list (APPEND FFT_SRC
module_fft/fft_dsp.cpp
module_fft/fft_dsp_float.cpp
pw_transform_k_dsp.cpp)
endif()

list(APPEND objects
pw_basis.cpp
Expand All @@ -36,6 +42,12 @@ add_library(
${objects}
)

if (USE_DSP)
target_link_libraries(planewave PRIVATE
${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a)
target_compile_definitions( planewave PUBLIC
FFT_DAT_DIR="${MTBLAS_FFT_DIR}/datfile/mt_fft_blas.dat")
endif()
if(ENABLE_COVERAGE)
add_coverage(planewave)
endif()
Expand Down
17 changes: 14 additions & 3 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include "module_base/tool_quit.h"
#if defined(__CUDA)
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rocm.h"
#endif

#if defined(__DSP)
#include "fft_dsp.h"
#endif
template<typename FFT_BASE, typename... Args>
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
{
Expand Down Expand Up @@ -40,7 +43,7 @@ void FFT_Bundle::initfft(int nx_in,
bool xprime_in ,
bool mpifft_in)
{
assert(this->device=="cpu" || this->device=="gpu");
assert(this->device=="cpu" || this->device=="gpu" || this->device=="dsp");
assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing");

if (this->precision=="single")
Expand All @@ -62,7 +65,15 @@ void FFT_Bundle::initfft(int nx_in,
{
double_flag = true;
}

#if defined(__DSP)
if (device=="dsp")
{
if (float_flag)
ModuleBase::WARNING_QUIT("device","now dsp fft is not support for the float type");
fft_double=make_unique<FFT_DSP<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
}
#endif
if (device=="cpu")
{
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
Expand Down
Loading
Loading