diff --git a/CMakeLists.txt b/CMakeLists.txt index d1487af838..fb02f66809 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -257,8 +257,12 @@ if(ENABLE_MPI) endif() if (USE_DSP) - target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY}) add_compile_definitions(__DSP) + target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1}) + include_directories(${MTBLAS_FFT_DIR}/libmtblas/include) + include_directories(${MT_HOST_DIR}/include) + target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a) + target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a) endif() find_package(Threads REQUIRED) @@ -429,10 +433,8 @@ else() find_package(Lapack REQUIRED) include_directories(${FFTW3_INCLUDE_DIRS}) list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS) - find_package(ScaLAPACK REQUIRED) list(APPEND math_libs ScaLAPACK::ScaLAPACK) - if(USE_OPENMP) list(APPEND math_libs FFTW3::FFTW3_OMP) endif() diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 769138b096..5acce9103f 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -104,6 +104,13 @@ if(USE_ROCM) ) endif() +if(USE_DSP) + list(APPEND device_srcs + module_base/kernels/dsp/dsp_connector.cpp + ) +endif() + + add_library(device OBJECT ${device_srcs}) if(USE_CUDA) diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index ecbdedcf6a..e6b016b311 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -65,7 +65,10 @@ add_library( ) target_link_libraries(base PUBLIC container) - +if (USE_DSP) + target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblas.a) + target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblasdev.a) +endif() add_subdirectory(module_container) if(ENABLE_COVERAGE) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index b422969ac5..5ccb7fc369 100644 --- 
a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -226,7 +226,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mth_(&transb, &transa, &n, &m, &k, + mtfunc::sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -240,79 +240,136 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } } -void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const double alpha, const double *a, const int lda, const double *b, const int ldb, - const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const double alpha, + const double* a, + const int lda, + const double* b, + const int ldb, + const double beta, + double* c, + const int ldc, + base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); -#endif - } -} - 
-void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck( + cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); +#endif + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - 
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc)); -#endif - } -} - -void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + n, + m, + k, + (float2*)&alpha, + (float2*)b, + ldb, + (float2*)a, + lda, + (float2*)&beta, + (float2*)c, + ldc)); +#endif + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + zgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if 
(device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc)); -#endif - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + n, + m, + k, + (double2*)&alpha, + (double2*)b, + ldb, + (double2*)a, + lda, + (double2*)&beta, + (double2*)c, + ldc)); +#endif + } } // Col-Major part @@ -327,7 +384,7 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mth_(&transb, &transa, &m, &n, &k, + mtfunc::sgemm_mth_(&transb, &transa, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -341,79 +398,136 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c } } -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const double alpha, const double *a, const int lda, const double *b, const int ldb, - const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const double alpha, + const double* a, + const int lda, + 
const double* b, + const int ldb, + const double beta, + double* c, + const int ldc, + base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mth_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); -#endif - } -} - -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck( + cublasDgemm(BlasUtils::cublas_handle, cutransA, 
cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mth_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); -#endif - } -} - -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = 
BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + m, + n, + k, + (float2*)&alpha, + (float2*)a, + lda, + (float2*)b, + ldb, + (float2*)&beta, + (float2*)c, + ldc)); +#endif + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mth_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); -#endif - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + m, + n, + k, + 
(double2*)&alpha, + (double2*)a, + lda, + (double2*)b, + ldb, + (double2*)&beta, + (double2*)c, + ldc)); +#endif + } } // Symm and Hemm part. Only col-major is supported. diff --git a/source/module_base/kernels/dsp/dsp_connector.cpp b/source/module_base/kernels/dsp/dsp_connector.cpp new file mode 100644 index 0000000000..a3c5f6d897 --- /dev/null +++ b/source/module_base/kernels/dsp/dsp_connector.cpp @@ -0,0 +1,335 @@ +#include "dsp_connector.h" + +#include +#include + +extern "C" +{ +#define complex_double ignore_complex_double +#include // MTBLAS_TRANSPOSE etc +#undef complex_double +#include // gemm +} +namespace mtfunc +{ +void dspInitHandle(int id) +{ + mt_blas_init(id); + std::cout << " ** DSP inited on cluster " << id << " **" << std::endl; +} // Use this at the beginning of the program to start a dsp cluster + +void dspDestoryHandle(int id) +{ + hthread_dev_close(id); + std::cout << " ** DSP closed on cluster " << id << " **" << std::endl; +} // Close dsp cluster at the end + +MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) +{ + switch (blasTrans[0]) + { + case 'N': + case 'n': + return MtblasNoTrans; + case 'T': + case 't': + return MtblasTrans; + case 'C': + case 'c': + return MtblasConjTrans; + default: + std::cout << "Invalid BLAS transpose parameter!! Use default instead." 
<< std::endl; + return MtblasNoTrans; + } +} // Used to convert normal transpose char to mtblas transpose flag + +void* malloc_ht(size_t bytes, int cluster_id) +{ + // std::cout << "MALLOC " << cluster_id; + void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW); + // std::cout << ptr << " SUCCEED" << std::endl;; + return ptr; +} + +// Used to replace original malloc + +void free_ht(void* ptr) +{ + // std::cout << "FREE " << ptr; + hthread_free(ptr); + // std::cout << " FREE SUCCEED" << std::endl; +} + +// Used to replace original free + +void sgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id) +{ + mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); +} // sgemm via mtblas_sgemm (column-major) that needn't malloc_ht or free_ht + +void dgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id) +{ + mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); +} // dgemm via mtblas_dgemm (column-major) that needn't malloc_ht or free_ht + +void zgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex<double>* alpha, + const std::complex<double>* a, + const int* lda, + const std::complex<double>* b, + const int* ldb, + const std::complex<double>* beta, + std::complex<double>* c, + const int* ldc, + int cluster_id) +{ + mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), +
convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alpha, + (const void*)a, + *lda, + (const void*)b, + *ldb, + (const void*)beta, + (void*)c, + *ldc, + cluster_id); +} // zgemm via mtblas_zgemm (column-major) that needn't malloc_ht or free_ht + +void cgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex<float>* alpha, + const std::complex<float>* a, + const int* lda, + const std::complex<float>* b, + const int* ldb, + const std::complex<float>* beta, + std::complex<float>* c, + const int* ldc, + int cluster_id) +{ + mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alpha, + (const void*)a, + *lda, + (const void*)b, + *ldb, + (const void*)beta, + (void*)c, + *ldc, + cluster_id); +} // cgemm via mtblas_cgemm (column-major) that needn't malloc_ht or free_ht + +// The *_mth_ wrappers below forward to the mt_hthread_*gemm entry points + +void sgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id) +{ + mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); +} // sgemm via mt_hthread_sgemm (column-major); alpha and beta are passed by value + +void dgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id) +{ + mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); +} // dgemm via mt_hthread_dgemm (column-major); alpha and beta are passed by value + +void zgemm_mth_(const char* transa, + const char* transb, + const
int* m, + const int* n, + const int* k, + const std::complex<double>* alpha, + const std::complex<double>* a, + const int* lda, + const std::complex<double>* b, + const int* ldb, + const std::complex<double>* beta, + std::complex<double>* c, + const int* ldc, + int cluster_id) +{ + std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id); + *alp = *alpha; + std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id); + *bet = *beta; + mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + alp, + a, + *lda, + b, + *ldb, + bet, + c, + *ldc, + cluster_id); + free_ht(alp); + free_ht(bet); +} // zgemm via mt_hthread_zgemm; alpha/beta are copied into malloc_ht buffers that are freed after the call + +void cgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex<float>* alpha, + const std::complex<float>* a, + const int* lda, + const std::complex<float>* b, + const int* ldb, + const std::complex<float>* beta, + std::complex<float>* c, + const int* ldc, + int cluster_id) +{ + std::complex<float>* alp = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id); + *alp = *alpha; + std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id); + *bet = *beta; + mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alp, + (const void*)a, + *lda, + (const void*)b, + *ldb, + (const void*)bet, + (void*)c, + *ldc, + cluster_id); + + free_ht(alp); + free_ht(bet); +} // cgemm via mt_hthread_cgemm; alpha/beta are copied into malloc_ht buffers that are freed after the call +} // namespace mtfunc \ No newline at end of file diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index ea0d17749e..bbda25f798 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -6,95 +6,157 @@ #include "module_base/module_device/memory_op.h" #include "module_hsolver/diag_comm_info.h" +namespace mtfunc +{ // Base dsp functions void
dspInitHandle(int id); void dspDestoryHandle(int id); -void *malloc_ht(size_t bytes, int cluster_id); +void* malloc_ht(size_t bytes, int cluster_id); void free_ht(void* ptr); - // mtblas functions -void sgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const float *beta, - float *c, const int *ldc, int cluster_id); - -void dgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const double *alpha,const double *a, const int *lda, - const double *b, const int *ldb, const double *beta, - double *c, const int *ldc, int cluster_id); - -void zgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - -void cgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - - -void sgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const float *beta, - float *c, const int *ldc, int cluster_id); - -void dgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const double *alpha,const double *a, const int *lda, - const double *b, const int *ldb, const double *beta, - double *c, const int *ldc, int cluster_id); - -void zgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, 
const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - -void cgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - -//#define zgemm_ zgemm_mt +void sgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id); + +void dgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id); + +void zgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void cgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void sgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id); + +void dgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + 
const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id); + +void zgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void cgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +// #define zgemm_ zgemm_mt // The next is dsp utils. It may be moved to other files if this file get too huge template -void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm){ +void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm) +{ - using syncmem_complex_op = base_device::memory::synchronize_memory_op; + using syncmem_complex_op + = base_device::memory::synchronize_memory_op; - auto* swap = new T[notconv * nbase_x]; + auto* swap = new T[notconv * nbase_x]; auto* target = new T[notconv * nbase_x]; syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x); if (base_device::get_current_precision(swap) == "single") { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm); } else { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_DOUBLE_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm); } syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x); @@ -102,30 +164,18 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int 
nbase, int nbase_x, int notconv if (base_device::get_current_precision(swap) == "single") { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm); } else { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_DOUBLE_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm); } syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x); delete[] swap; delete[] target; } - +} // namespace mtfunc #endif #endif \ No newline at end of file diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 525ecee89f..9af0ce5a79 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -340,9 +340,9 @@ struct resize_memory_op_mt { if (arr != nullptr) { - free_ht(arr); + mtfunc::free_ht(arr); } - arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); + arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); std::string record_string; if (record_in != nullptr) { @@ -365,7 +365,7 @@ struct delete_memory_op_mt { void operator()(FPTYPE* arr) { - free_ht(arr); + mtfunc::free_ht(arr); } }; diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 549e41c93c..e365e12b5e 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -13,6 +13,12 @@ if (USE_ROCM) module_fft/fft_rocm.cpp ) endif() +if (USE_DSP) + list (APPEND FFT_SRC + module_fft/fft_dsp.cpp + module_fft/fft_dsp_float.cpp + pw_transform_k_dsp.cpp) +endif() list(APPEND objects pw_basis.cpp @@ -36,6 +42,12 @@ add_library( ${objects} ) +if (USE_DSP) +target_link_libraries(planewave PRIVATE +${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a) +target_compile_definitions( planewave PUBLIC 
+FFT_DAT_DIR="${MTBLAS_FFT_DIR}/datfile/mt_fft_blas.dat") +endif() if(ENABLE_COVERAGE) add_coverage(planewave) endif() diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index b64b6f4e00..b7c63fc9b1 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -7,166 +7,150 @@ namespace ModulePW template class FFT_BASE { -public: + public: + FFT_BASE() {}; + virtual ~FFT_BASE() {}; - FFT_BASE(){}; - virtual ~FFT_BASE(){}; - /** * @brief Initialize the fft parameters As virtual function. - * + * * The function is used to initialize the fft parameters. */ - virtual __attribute__((weak)) - void initfft(int nx_in, - int ny_in, - int nz_in, - int lixy_in, - int rixy_in, - int ns_in, - int nplane_in, - int nproc_in, - bool gamma_only_in, - bool xprime_in = true); - - virtual __attribute__((weak)) - void initfft(int nx_in, - int ny_in, - int nz_in); + virtual __attribute__((weak)) void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true); + + virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in); /** * @brief Setup the fft Plan and data As pure virtual function. - * + * * The function is set as pure virtual function.In order to * override the function in the derived class.In the derived * class, the function is used to setup the fft Plan and data. */ - virtual void setupFFT()=0; + virtual void setupFFT() = 0; /** * @brief Clean the fft Plan As pure virtual function. - * + * * The function is set as pure virtual function.In order to * override the function in the derived class.In the derived * class, the function is used to clean the fft Plan. */ - virtual void cleanFFT()=0; - + virtual void cleanFFT() = 0; + /** * @brief Clear the fft data As pure virtual function. 
- * + * * The function is set as pure virtual function.In order to * override the function in the derived class.In the derived * class, the function is used to clear the fft data. */ - virtual void clear()=0; - + virtual void clear() = 0; + + virtual void resource_handler(const int flag) const {}; /** * @brief Get the real space data in cpu-like fft - * + * * The function is used to get the real space data.While the * FFT_BASE is an abstract class,the function will be override, - * The attribute weak is used to avoid define the function. + * The attribute weak is used to avoid define the function. */ - virtual __attribute__((weak)) - FPTYPE* get_rspace_data() const; + virtual __attribute__((weak)) FPTYPE* get_rspace_data() const; - virtual __attribute__((weak)) - std::complex* get_auxr_data() const; + virtual __attribute__((weak)) std::complex* get_auxr_data() const; - virtual __attribute__((weak)) - std::complex* get_auxg_data() const; + virtual __attribute__((weak)) std::complex* get_auxg_data() const; /** * @brief Get the auxiliary real space data in 3D - * + * * The function is used to get the auxiliary real space data in 3D. * While the FFT_BASE is an abstract class,the function will be override, * The attribute weak is used to avoid define the function. */ - virtual __attribute__((weak)) - std::complex* get_auxr_3d_data() const; + virtual __attribute__((weak)) std::complex* get_auxr_3d_data() const; - //forward fft in x-y direction + // forward fft in x-y direction /** * @brief Forward FFT in x-y direction * @param in input data * @param out output data - * + * * This function performs the forward FFT in the x-y direction. * It involves two axes, x and y. 
The FFT is applied multiple times - * along the left and right boundaries in the primary direction(which is - * determined by the xprime flag).Notably, the Y axis operates in + * along the left and right boundaries in the primary direction(which is + * determined by the xprime flag).Notably, the Y axis operates in * "many-many-FFT" mode. */ - virtual __attribute__((weak)) - void fftxyfor(std::complex* in, - std::complex* out) const; + virtual __attribute__((weak)) void fftxyfor(std::complex* in, + std::complex* out) const; - virtual __attribute__((weak)) - void fftxybac(std::complex* in, - std::complex* out) const; + virtual __attribute__((weak)) void fftxybac(std::complex* in, + std::complex* out) const; /** * @brief Forward FFT in z direction * @param in input data * @param out output data - * + * * This function performs the forward FFT in the z direction. * It involves only one axis, z. The FFT is applied only once. * Notably, the Z axis operates in many FFT with nz*ns. */ - virtual __attribute__((weak)) - void fftzfor(std::complex* in, - std::complex* out) const; - - virtual __attribute__((weak)) - void fftzbac(std::complex* in, - std::complex* out) const; + virtual __attribute__((weak)) void fftzfor(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) void fftzbac(std::complex* in, + std::complex* out) const; /** * @brief Forward FFT in x-y direction with real to complex * @param in input data, real type * @param out output data, complex type - * - * This function performs the forward FFT in the x-y direction + * + * This function performs the forward FFT in the x-y direction * with real to complex.There is no difference between fftxyfor. 
*/ - virtual __attribute__((weak)) - void fftxyr2c(FPTYPE* in, - std::complex* out) const; - - virtual __attribute__((weak)) - void fftxyc2r(std::complex* in, - FPTYPE* out) const; - + virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in, + std::complex* out) const; + + virtual __attribute__((weak)) void fftxyc2r(std::complex* in, + FPTYPE* out) const; + /** * @brief Forward FFT in 3D * @param in input data * @param out output data - * + * * This function performs the forward FFT for gpu-like fft. * It involves three axes, x, y, and z. The FFT is applied multiple times * for fft3D_forward. */ - virtual __attribute__((weak)) - void fft3D_forward(std::complex* in, - std::complex* out) const; - - virtual __attribute__((weak)) - void fft3D_backward(std::complex* in, - std::complex* out) const; - -protected: - int nx=0; - int ny=0; - int nz=0; + virtual __attribute__((weak)) void fft3D_forward(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) void fft3D_backward(std::complex* in, + std::complex* out) const; + + protected: + int nx = 0; + int ny = 0; + int nz = 0; }; template FFT_BASE::FFT_BASE(); template FFT_BASE::FFT_BASE(); template FFT_BASE::~FFT_BASE(); template FFT_BASE::~FFT_BASE(); -} +} // namespace ModulePW #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index c2718abf5d..7289e8ab02 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,17 +1,21 @@ -#include #include "fft_bundle.h" #include "module_base/module_device/device.h" #include "module_base/module_device/memory_op.h" +#include "module_base/tool_quit.h" + +#include #if defined(__CUDA) #include "fft_cuda.h" #endif #if defined(__ROCM) #include "fft_rocm.h" #endif - -template -std::unique_ptr make_unique(Args &&... 
args) +#if defined(__DSP) +#include "fft_dsp.h" +#endif +template +std::unique_ptr make_unique(Args&&... args) { return std::unique_ptr(new FFT_BASE(std::forward(args)...)); } @@ -22,208 +26,278 @@ FFT_Bundle::~FFT_Bundle() this->clear(); } -void FFT_Bundle::setfft(std::string device_in,std::string precision_in) +void FFT_Bundle::setfft(std::string device_in, std::string precision_in) { this->device = device_in; this->precision = precision_in; } -void FFT_Bundle::initfft(int nx_in, - int ny_in, - int nz_in, - int lixy_in, - int rixy_in, - int ns_in, - int nplane_in, - int nproc_in, - bool gamma_only_in, - bool xprime_in , +void FFT_Bundle::initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in, bool mpifft_in) { - assert(this->device=="cpu" || this->device=="gpu"); - assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing"); + assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp"); + assert(this->precision == "single" || this->precision == "double" || this->precision == "mixing"); - if (this->precision=="single") + if (this->precision == "single") { - #if not defined (__ENABLE_FLOAT_FFTW) - if (this->device == "cpu"){ +#if not defined(__ENABLE_FLOAT_FFTW) + if (this->device == "cpu") + { float_define = false; } - #endif - #if defined(__CUDA) || defined (__ROCM) - if (this->device == "gpu"){ +#endif +#if defined(__CUDA) || defined(__ROCM) + if (this->device == "gpu") + { float_flag = float_define; } - #endif +#endif float_flag = float_define; double_flag = true; } - if (this->precision=="double") + if (this->precision == "double") { double_flag = true; } - - if (device=="cpu") +#if defined(__DSP) + if (device == "dsp") + { + if (float_flag) + { + ModuleBase::WARNING_QUIT("device", "now dsp fft is not supported for the float type"); + } + fft_double = make_unique>(); + fft_double->initfft(nx_in, ny_in, 
nz_in); + } +#endif + if (device == "cpu") { fft_float = make_unique>(this->fft_mode); fft_double = make_unique>(this->fft_mode); if (float_flag) { - fft_float->initfft(nx_in, - ny_in, - nz_in, - lixy_in, - rixy_in, - ns_in, - nplane_in, - nproc_in, - gamma_only_in, - xprime_in); + fft_float + ->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in); } if (double_flag) { - fft_double->initfft(nx_in, - ny_in, - nz_in, - lixy_in, - rixy_in, - ns_in, - nplane_in, - nproc_in, - gamma_only_in, - xprime_in); + fft_double + ->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in); } } - if (device=="gpu") + if (device == "gpu") { - #if defined(__ROCM) - fft_float = make_unique>(); - fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); - fft_double->initfft(nx_in,ny_in,nz_in); - #elif defined(__CUDA) - fft_float = make_unique>(); - fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); - fft_double->initfft(nx_in,ny_in,nz_in); - #endif +#if defined(__ROCM) + fft_float = make_unique>(); + fft_float->initfft(nx_in, ny_in, nz_in); + fft_double = make_unique>(); + fft_double->initfft(nx_in, ny_in, nz_in); +#elif defined(__CUDA) + fft_float = make_unique>(); + fft_float->initfft(nx_in, ny_in, nz_in); + fft_double = make_unique>(); + fft_double->initfft(nx_in, ny_in, nz_in); +#endif } - } void FFT_Bundle::setupFFT() { - if (double_flag){fft_double->setupFFT();} - if (float_flag) {fft_float->setupFFT();} + if (double_flag) + { + fft_double->setupFFT(); + } + if (float_flag) + { + fft_float->setupFFT(); + } } void FFT_Bundle::clearFFT() { - if (double_flag){fft_double->cleanFFT();} - if (float_flag) {fft_float->cleanFFT();} + if (double_flag) + { + fft_double->cleanFFT(); + } + if (float_flag) + { + fft_float->cleanFFT(); + } } void FFT_Bundle::clear() { this->clearFFT(); - if (double_flag){fft_double->clear();} - if (float_flag) {fft_float->clear();} -} - 
-template <> void -FFT_Bundle::fftxyfor(std::complex* in, - std::complex* out) -const {fft_float->fftxyfor(in,out);} -template <> void -FFT_Bundle::fftxyfor(std::complex* in, - std::complex* out) -const {fft_double->fftxyfor(in,out);} - - -template <> void -FFT_Bundle::fftzfor(std::complex* in, - std::complex* out) -const {fft_float->fftzfor(in,out);} -template <> void -FFT_Bundle::fftzfor(std::complex* in, - std::complex* out) -const {fft_double->fftzfor(in,out);} - -template <> void -FFT_Bundle::fftxybac(std::complex* in, - std::complex* out) -const {fft_float->fftxybac(in,out);} -template <> void -FFT_Bundle::fftxybac(std::complex* in, - std::complex* out) -const {fft_double->fftxybac(in,out);} - -template <> void -FFT_Bundle::fftzbac(std::complex* in, - std::complex* out) -const {fft_float->fftzbac(in,out);} -template <> void -FFT_Bundle::fftzbac(std::complex* in, - std::complex* out) -const {fft_double->fftzbac(in,out);} - -template <> void -FFT_Bundle::fftxyr2c(float* in, - std::complex* out) -const {fft_float->fftxyr2c(in,out);} -template <> void -FFT_Bundle::fftxyr2c(double* in, - std::complex* out) -const {fft_double->fftxyr2c(in,out);} - -template <> void -FFT_Bundle::fftxyc2r(std::complex* in, - float* out) -const {fft_float->fftxyc2r(in,out);} -template <> void -FFT_Bundle::fftxyc2r(std::complex* in, - double* out) -const {fft_double->fftxyc2r(in,out);} - -template <> void -FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_float->fft3D_forward(in, out);} -template <> void -FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_double->fft3D_forward(in, out);} - -template <> void -FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_float->fft3D_backward(in, out);} -template <> void -FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - 
std::complex* out) -const {fft_double->fft3D_backward(in, out);} + if (double_flag) + { + fft_double->clear(); + } + if (float_flag) + { + fft_float->clear(); + } +} + +void FFT_Bundle::resource_handler(const int flag) const +{ + if (this->device=="dsp") + { + if (double_flag) + { + fft_double->resource_handler(flag); + } + if (float_flag) + { + fft_float->resource_handler(flag); + } + } +} +template <> +void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_float->fftxyfor(in, out); +} +template <> +void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_double->fftxyfor(in, out); +} + +template <> +void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const +{ + fft_float->fftzfor(in, out); +} +template <> +void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const +{ + fft_double->fftzfor(in, out); +} + +template <> +void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const +{ + fft_float->fftxybac(in, out); +} +template <> +void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const +{ + fft_double->fftxybac(in, out); +} + +template <> +void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const +{ + fft_float->fftzbac(in, out); +} +template <> +void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const +{ + fft_double->fftzbac(in, out); +} + +template <> +void FFT_Bundle::fftxyr2c(float* in, std::complex* out) const +{ + fft_float->fftxyr2c(in, out); +} +template <> +void FFT_Bundle::fftxyr2c(double* in, std::complex* out) const +{ + fft_double->fftxyr2c(in, out); +} + +template <> +void FFT_Bundle::fftxyc2r(std::complex* in, float* out) const +{ + fft_float->fftxyc2r(in, out); +} +template <> +void FFT_Bundle::fftxyc2r(std::complex* in, double* out) const +{ + fft_double->fftxyc2r(in, out); +} + +template <> +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_float->fft3D_forward(in, out); 
+} +template <> +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_double->fft3D_forward(in, out); +} + +template <> +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_float->fft3D_backward(in, out); +} +template <> +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_double->fft3D_backward(in, out); +} // access the real space data -template <> float* -FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();} -template <> double* -FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();} - -template <> std::complex* -FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();} -template <> std::complex* -FFT_Bundle::get_auxr_data() const {return fft_double->get_auxr_data();} - -template <> std::complex* -FFT_Bundle::get_auxg_data() const {return fft_float->get_auxg_data();} -template <> std::complex* -FFT_Bundle::get_auxg_data() const {return fft_double->get_auxg_data();} - -template <> std::complex* -FFT_Bundle::get_auxr_3d_data() const {return fft_float->get_auxr_3d_data();} -template <> std::complex* -FFT_Bundle::get_auxr_3d_data() const {return fft_double->get_auxr_3d_data();} -} \ No newline at end of file +template <> +float* FFT_Bundle::get_rspace_data() const +{ + return fft_float->get_rspace_data(); +} +template <> +double* FFT_Bundle::get_rspace_data() const +{ + return fft_double->get_rspace_data(); +} + +template <> +std::complex* FFT_Bundle::get_auxr_data() const +{ + return fft_float->get_auxr_data(); +} +template <> +std::complex* FFT_Bundle::get_auxr_data() const +{ + return fft_double->get_auxr_data(); +} + +template <> +std::complex* FFT_Bundle::get_auxg_data() const +{ + return fft_float->get_auxg_data(); +} +template <> +std::complex* FFT_Bundle::get_auxg_data() const +{ + return 
fft_double->get_auxg_data(); +} + +template <> +std::complex* FFT_Bundle::get_auxr_3d_data() const +{ + return fft_float->get_auxr_3d_data(); +} +template <> +std::complex* FFT_Bundle::get_auxr_3d_data() const +{ + return fft_double->get_auxr_3d_data(); +} +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 71ce5192f3..1982a79a0c 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -1,215 +1,208 @@ #ifndef FFT_TEMP_H #define FFT_TEMP_H -#include #include "fft_base.h" #include "fft_cpu.h" + +#include namespace ModulePW { class FFT_Bundle { - public: - FFT_Bundle(){}; - ~FFT_Bundle(); - /** - * @brief Constructor with device and precision. - * @param device_in device type, cpu or gpu. - * @param precision_in precision type, single or double. - * - * the function will check the input device and precision, - * and set the device and precision. - */ - FFT_Bundle(std::string device_in,std::string precision_in) - :device(device_in),precision(precision_in){}; - - /** - * @brief Set device and precision. - * @param device_in device type, cpu or gpu. - * @param precision_in precision type, single or double. - * - * the function will check the input device and precision, - * and set the device and precision. - */ - void setfft(std::string device_in,std::string precision_in); + public: + FFT_Bundle() {}; + ~FFT_Bundle(); + /** + * @brief Constructor with device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ + FFT_Bundle(std::string device_in, std::string precision_in) : device(device_in), precision(precision_in) {}; + + /** + * @brief Set device and precision. 
+ * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ + void setfft(std::string device_in, std::string precision_in); + + /** + * @brief Initialize the fft parameters. + * @param nx_in number of grid points in x direction. + * @param ny_in number of grid points in y direction. + * @param nz_in number of grid points in z direction. + * @param lixy_in the position of the left boundary + * in the x-y plane. + * @param rixy_in the position of the right boundary + * in the x-y plane. + * @param ns_in number of stick whcih is used in the + * Z direction. + * @param nplane_in number of x-y planes. + * @param nproc_in number of processors. + * @param gamma_only_in whether only gamma point is used. + * @param xprime_in whether xprime is used. + * + * the function will initialize the many-fft parameters + * Wheatley in cpu or gpu device. + */ + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true, + bool mpifft_in = false); + + /** + * @brief Initialize the fft mode. + * @param fft_mode_in fft mode. + * + * the function will initialize the fft mode. + */ - /** - * @brief Initialize the fft parameters. - * @param nx_in number of grid points in x direction. - * @param ny_in number of grid points in y direction. - * @param nz_in number of grid points in z direction. - * @param lixy_in the position of the left boundary - * in the x-y plane. - * @param rixy_in the position of the right boundary - * in the x-y plane. - * @param ns_in number of stick whcih is used in the - * Z direction. - * @param nplane_in number of x-y planes. - * @param nproc_in number of processors. - * @param gamma_only_in whether only gamma point is used. - * @param xprime_in whether xprime is used. 
- * - * the function will initialize the many-fft parameters - * Wheatley in cpu or gpu device. - */ - void initfft(int nx_in, - int ny_in, - int nz_in, - int lixy_in, - int rixy_in, - int ns_in, - int nplane_in, - int nproc_in, - bool gamma_only_in, - bool xprime_in = true, - bool mpifft_in = false); - - /** - * @brief Initialize the fft mode. - * @param fft_mode_in fft mode. - * - * the function will initialize the fft mode. - */ + void initfftmode(int fft_mode_in) + { + this->fft_mode = fft_mode_in; + } - void initfftmode(int fft_mode_in){this->fft_mode = fft_mode_in;} + void setupFFT(); - void setupFFT(); + void clearFFT(); - void clearFFT(); - - void clear(); - - /** - * @brief Get the real space data. - * @return FPTYPE* the real space data. - * - * the function will return the real space data, - * which is used in the cpu-like fft. - */ - template - FPTYPE* get_rspace_data() const; - /** - * @brief Get the auxr data. - * @return std::complex* the auxr data. - * - * the function will return the auxr data, - * which is used in the cpu-like fft. - */ - template - std::complex* get_auxr_data() const; - /** - * @brief Get the auxg data. - * @return std::complex* the auxg data. - * - * the function will return the auxg data, - * which is used in the cpu-like fft. - */ - template - std::complex* get_auxg_data() const; - /** - * @brief Get the auxr 3d data. - * @return std::complex* the auxr 3d data. - * - * the function will return the auxr 3d data, - * which is used in the gpu-like fft. - */ - template - std::complex* get_auxr_3d_data() const; - - /** - * @brief Forward fft in z direction. - * @param in input data. - * @param out output data. - * - * The function will do the forward many fft in z direction, - * As an interface, the function will call the fftzfor in the - * accurate fft class. - * which is used in the cpu-like fft. - */ - template - void fftzfor(std::complex* in, - std::complex* out) const; - /** - * @brief Forward fft in x-y direction. 
- * @param in input data. - * @param out output data. - * - * the function will do the forward fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxyfor in the accurate fft class. - */ - template - void fftxyfor(std::complex* in, - std::complex* out) const; - /** - * @brief Backward fft in z direction. - * @param in input data. - * @param out output data. - * - * the function will do the backward many fft in z direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftzbac in the accurate fft class. - */ - template - void fftzbac(std::complex* in, - std::complex* out) const; - /** - * @brief Backward fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the backward fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxybac in the accurate fft class. - */ - template - void fftxybac(std::complex* in, - std::complex* out) const; - - /** - * @brief Real to complex fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the real to complex fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxyr2c in the accurate fft class. - */ - template - void fftxyr2c(FPTYPE* in, - std::complex* out) const; - /** - * @brief Complex to real fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the complex to real fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxyc2r in the accurate fft class. 
- */ - template - void fftxyc2r(std::complex* in, - FPTYPE* out) const; + void clear(); - template - void fft3D_forward(const Device* ctx, - std::complex* in, - std::complex* out) const; - template - void fft3D_backward(const Device* ctx, - std::complex* in, - std::complex* out) const; + void resource_handler(const int flag) const; + /** + * @brief Get the real space data. + * @return FPTYPE* the real space data. + * + * the function will return the real space data, + * which is used in the cpu-like fft. + */ + template + FPTYPE* get_rspace_data() const; + /** + * @brief Get the auxr data. + * @return std::complex* the auxr data. + * + * the function will return the auxr data, + * which is used in the cpu-like fft. + */ + template + std::complex* get_auxr_data() const; + /** + * @brief Get the auxg data. + * @return std::complex* the auxg data. + * + * the function will return the auxg data, + * which is used in the cpu-like fft. + */ + template + std::complex* get_auxg_data() const; + /** + * @brief Get the auxr 3d data. + * @return std::complex* the auxr 3d data. + * + * the function will return the auxr 3d data, + * which is used in the gpu-like fft. + */ + template + std::complex* get_auxr_3d_data() const; - private: - int fft_mode = 0; - bool float_flag=false; - bool float_define=true; - bool double_flag=false; - std::shared_ptr> fft_float=nullptr; - std::shared_ptr> fft_double=nullptr; - - std::string device = "cpu"; - std::string precision = "double"; -}; + /** + * @brief Forward fft in z direction. + * @param in input data. + * @param out output data. + * + * The function will do the forward many fft in z direction, + * As an interface, the function will call the fftzfor in the + * accurate fft class. + * which is used in the cpu-like fft. + */ + template + void fftzfor(std::complex* in, std::complex* out) const; + /** + * @brief Forward fft in x-y direction. + * @param in input data. + * @param out output data. 
+ * + * the function will do the forward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyfor in the accurate fft class. + */ + template + void fftxyfor(std::complex* in, std::complex* out) const; + /** + * @brief Backward fft in z direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward many fft in z direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftzbac in the accurate fft class. + */ + template + void fftzbac(std::complex* in, std::complex* out) const; + /** + * @brief Backward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxybac in the accurate fft class. + */ + template + void fftxybac(std::complex* in, std::complex* out) const; + + /** + * @brief Real to complex fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the real to complex fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyr2c in the accurate fft class. + */ + template + void fftxyr2c(FPTYPE* in, std::complex* out) const; + /** + * @brief Complex to real fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the complex to real fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyc2r in the accurate fft class. 
+ */ + template + void fftxyc2r(std::complex* in, FPTYPE* out) const; + + template + void fft3D_forward(const Device* ctx, std::complex* in, std::complex* out) const; + template + void fft3D_backward(const Device* ctx, std::complex* in, std::complex* out) const; + + private: + int fft_mode = 0; + bool float_flag = false; + bool float_define = true; + bool double_flag = false; + std::shared_ptr> fft_float = nullptr; + std::shared_ptr> fft_double = nullptr; + + std::string device = "cpu"; + std::string precision = "double"; +}; } // namespace ModulePW #endif // FFT_H - diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp new file mode 100644 index 0000000000..0247ac84a7 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -0,0 +1,125 @@ +#include "fft_dsp.h" + +#include "module_base/global_variable.h" + +#include +#include +#include +namespace ModulePW +{ +template <> +void FFT_DSP::initfft(int nx_in, int ny_in, int nz_in) +{ + this->nx = nx_in; + this->ny = ny_in; + this->nz = nz_in; + cluster_id = GlobalV::MY_RANK; + nxyz = this->nx * this->ny * this->nz; +} +template <> +void FFT_DSP::setupFFT() +{ + PROBLEM pbm_forward; + PROBLEM pbm_backward; + PLAN* ptr_plan_forward; + PLAN* ptr_plan_backward; + INT num_thread = 8; + INT size=0; + hthread_dat_load(cluster_id, FFT_DAT_DIR); + + // compute the size of and malloc thread + size = nx * ny * nz * 2 * sizeof(E); + forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW); + + //init 3d fft problem + pbm_forward.num_dim = 3; + pbm_forward.n[0] = nx; + pbm_forward.n[1] = ny; + pbm_forward.n[2] = nz; + pbm_forward.iFFT = 0; + pbm_forward.in = forward_in; + pbm_forward.out = forward_in; + + //make ptr plan + make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread); + ptr_plan_forward->in = forward_in; + ptr_plan_forward->out = forward_in; + args_for[1] = (unsigned long)ptr_plan_forward; + + // init 3d fft problem + 
pbm_backward.num_dim = 3; + pbm_backward.n[0] = nx; + pbm_backward.n[1] = ny; + pbm_backward.n[2] = nz; + pbm_backward.iFFT = 1; + pbm_backward.in = forward_in; + pbm_backward.out = forward_in; + + make_plan(&pbm_backward, &ptr_plan_backward, cluster_id, num_thread); + ptr_plan_backward->in = forward_in; + ptr_plan_backward->out = forward_in; + args_back[1] = (unsigned long)ptr_plan_backward; +} +template <> +void FFT_DSP::resource_handler(const int flag) const +{ + if (flag==0) + { + hthread_barrier_destroy(b_id); + hthread_group_destroy(thread_id_for); + } + else if (flag==1) + { + INT num_thread = 8; + thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); + // create b_id for the barrier + b_id = hthread_barrier_create(cluster_id); + args_for[0] = b_id; + args_back[0] = b_id; + } +} +template <> +void FFT_DSP::fft3D_forward(std::complex* in, std::complex* out) const +{ + hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for); + hthread_group_wait(thread_id_for); +} + +template <> +void FFT_DSP::fft3D_backward(std::complex* in, std::complex* out) const +{ + hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back); + hthread_group_wait(thread_id_for); +} +template <> +void FFT_DSP::cleanFFT() +{ + if (ptr_plan_forward != nullptr) + { + destroy_plan(ptr_plan_forward); + ptr_plan_forward = nullptr; + } + if (ptr_plan_backward != nullptr) + { + destroy_plan(ptr_plan_backward); + ptr_plan_backward = nullptr; + } +} + +template <> +void FFT_DSP::clear() +{ + this->cleanFFT(); + hthread_free(forward_in); +} + +template <> +std::complex* FFT_DSP::get_auxr_3d_data() const +{ + return reinterpret_cast*>(this->forward_in); +} +template FFT_DSP::FFT_DSP(); +template FFT_DSP::~FFT_DSP(); +template FFT_DSP::FFT_DSP(); +template FFT_DSP::~FFT_DSP(); +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.h b/source/module_basis/module_pw/module_fft/fft_dsp.h new file 
mode 100644 index 0000000000..0cdfe84fc6 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_dsp.h @@ -0,0 +1,82 @@ +#ifndef FFT_DSP_H +#define FFT_DSP_H + +#include "fft_base.h" +#include +#include +#include + +#include "hthread_host.h" +#include "mtfft.h" +#include "fftw3.h" + +namespace ModulePW +{ +template +class FFT_DSP : public FFT_BASE +{ + public: + FFT_DSP(){}; + ~FFT_DSP(){}; + + void setupFFT() override; + + void clear() override; + + void cleanFFT() override; + + /** + * @brief Initialize the fft parameters + * @param nx_in number of grid points in x direction + * @param ny_in number of grid points in y direction + * @param nz_in number of grid points in z direction + * + */ + virtual __attribute__((weak)) + void initfft(int nx_in, + int ny_in, + int nz_in) override; + + /** + * @brief Get the real space data + * @return real space data + */ + virtual __attribute__((weak)) + std::complex* get_auxr_3d_data() const override; + + /** + * @brief Forward FFT in 3D + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + * + * This function performs the forward FFT in 3D. + */ + virtual __attribute__((weak)) + void fft3D_forward(std::complex* in, + std::complex* out) const override; + /** + * @brief Backward FFT in 3D + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + * + * This function performs the backward FFT in 3D. 
+ */ + virtual __attribute__((weak)) + void fft3D_backward(std::complex* in, + std::complex* out) const override; + public: + int nxyz=0; + INT cluster_id=0; + mutable INT b_id=0; + mutable INT thread_id_for=0; + PLAN* ptr_plan_forward=nullptr; + PLAN* ptr_plan_backward=nullptr; + mutable unsigned long args_for[2]; + mutable unsigned long args_back[2]; + E * forward_in=nullptr; + std::complex* c_auxr_3d = nullptr; // fft space + std::complex* z_auxr_3d = nullptr; // fft space + +}; +} // namespace ModulePW +#endif \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp b/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp new file mode 100644 index 0000000000..3c11cfc81f --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp @@ -0,0 +1,25 @@ +#include "fft_dsp.h" +namespace ModulePW +{ + +template<> +void FFT_DSP::setupFFT() +{ + +} +template<> +void FFT_DSP::clear() +{ + +} +template<> +void FFT_DSP::cleanFFT() +{ + +} +template<> +void FFT_DSP::resource_handler(const int flag) const +{ + +} +} \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 2e0f85372d..08391242ea 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -1,18 +1,18 @@ #include "pw_basis_k.h" -#include - #include "module_base/constants.h" #include "module_base/memory.h" #include "module_base/timer.h" #include "module_parameter/parameter.h" + +#include namespace ModulePW { PW_Basis_K::PW_Basis_K() { - classname="PW_Basis_K"; - this->fft_bundle.setfft(this->device,this->precision); + classname = "PW_Basis_K"; + this->fft_bundle.setfft(this->device, this->precision); } PW_Basis_K::~PW_Basis_K() { @@ -23,13 +23,16 @@ PW_Basis_K::~PW_Basis_K() delete[] igl2ig_k; delete[] gk2; #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { - if (this->precision == "single") { + if 
(this->device == "gpu") + { + if (this->precision == "single") + { delmem_sd_op()(this->s_kvec_c); delmem_sd_op()(this->s_gcar); delmem_sd_op()(this->s_gk2); } - else { + else + { delmem_dd_op()(this->d_gcar); delmem_dd_op()(this->d_gk2); } @@ -37,9 +40,11 @@ PW_Basis_K::~PW_Basis_K() delmem_int_op()(this->ig2ixyz_k); delmem_int_op()(this->d_igl2isz_k); } - else { + else + { #endif - if (this->precision == "single") { + if (this->precision == "single") + { delmem_sh_op()(this->s_kvec_c); delmem_sh_op()(this->s_gcar); delmem_sh_op()(this->s_gk2); @@ -50,68 +55,81 @@ PW_Basis_K::~PW_Basis_K() #endif } -void PW_Basis_K:: initparameters( - const bool gamma_only_in, - const double gk_ecut_in, - const int nks_in, //number of k points in this pool - const ModuleBase::Vector3 *kvec_d_in, // Direct coordinates of k points - const int distribution_type_in, - const bool xprime_in -) +void PW_Basis_K::initparameters(const bool gamma_only_in, + const double gk_ecut_in, + const int nks_in, // number of k points in this pool + const ModuleBase::Vector3* kvec_d_in, // Direct coordinates of k points + const int distribution_type_in, + const bool xprime_in) { this->nks = nks_in; - delete[] this->kvec_d; this->kvec_d = new ModuleBase::Vector3 [nks]; - delete[] this->kvec_c; this->kvec_c = new ModuleBase::Vector3 [nks]; + delete[] this->kvec_d; + this->kvec_d = new ModuleBase::Vector3[nks]; + delete[] this->kvec_c; + this->kvec_c = new ModuleBase::Vector3[nks]; double kmaxmod = 0; - for(int ik = 0 ; ik < this->nks ; ++ik) + for (int ik = 0; ik < this->nks; ++ik) { this->kvec_d[ik] = kvec_d_in[ik]; this->kvec_c[ik] = this->kvec_d[ik] * this->G; double kmod = sqrt(this->kvec_c[ik] * this->kvec_c[ik]); - if(kmod > kmaxmod) { kmaxmod = kmod; -} + if (kmod > kmaxmod) + { + kmaxmod = kmod; + } } - this->gk_ecut = gk_ecut_in/this->tpiba2; + this->gk_ecut = gk_ecut_in / this->tpiba2; this->ggecut = pow(sqrt(this->gk_ecut) + kmaxmod, 2); - if(this->ggecut > this->gridecut_lat) + if 
(this->ggecut > this->gridecut_lat) { this->ggecut = this->gridecut_lat; - this->gk_ecut = pow(sqrt(this->ggecut) - kmaxmod ,2); + this->gk_ecut = pow(sqrt(this->ggecut) - kmaxmod, 2); } this->gamma_only = gamma_only_in; - if(kmaxmod > 0) { this->gamma_only = false; //if it is not the gamma point, we do not use gamma_only -} + if (kmaxmod > 0) + { + this->gamma_only = false; // if it is not the gamma point, we do not use gamma_only + } this->xprime = xprime_in; this->fftny = this->ny; this->fftnx = this->nx; - if (this->gamma_only) + if (this->gamma_only) { - if(this->xprime) { this->fftnx = int(this->nx / 2) + 1; - } else { this->fftny = int(this->ny / 2) + 1; -} + if (this->xprime) + { + this->fftnx = int(this->nx / 2) + 1; + } + else + { + this->fftny = int(this->ny / 2) + 1; + } } this->fftnz = this->nz; this->fftnxy = this->fftnx * this->fftny; this->fftnxyz = this->fftnxy * this->fftnz; this->distribution_type = distribution_type_in; #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { - if (this->precision == "single") { + if (this->device == "gpu") + { + if (this->precision == "single") + { resmem_sd_op()(this->s_kvec_c, this->nks * 3); - castmem_d2s_h2d_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); + castmem_d2s_h2d_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } resmem_dd_op()(this->d_kvec_c, this->nks * 3); - syncmem_d2d_h2d_op()(this->d_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); + syncmem_d2d_h2d_op()(this->d_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } - else { + else + { #endif - if (this->precision == "single") { + if (this->precision == "single") + { resmem_sh_op()(this->s_kvec_c, this->nks * 3); - castmem_d2s_h2h_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); + castmem_d2s_h2h_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } - this->d_kvec_c = 
reinterpret_cast(&this->kvec_c[0][0]); + this->d_kvec_c = reinterpret_cast(&this->kvec_c[0][0]); // There's no need to allocate double pointers while in a CPU environment. #if defined(__CUDA) || defined(__ROCM) } @@ -120,50 +138,59 @@ void PW_Basis_K:: initparameters( void PW_Basis_K::setupIndGk() { - //count npwk + // count npwk this->npwk_max = 0; - delete[] this->npwk; this->npwk = new int [this->nks]; + delete[] this->npwk; + this->npwk = new int[this->nks]; for (int ik = 0; ik < this->nks; ik++) { int ng = 0; - for (int ig = 0; ig < this->npw ; ig++) + for (int ig = 0; ig < this->npw; ig++) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); + const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { ++ng; } } this->npwk[ik] = ng; - ModuleBase::CHECK_WARNING_QUIT((ng == 0), "pw_basis_k.cpp", PARAM.inp.calculation,"Current core has no plane waves! Please reduce the cores."); - if ( this->npwk_max < ng) + ModuleBase::CHECK_WARNING_QUIT((ng == 0), + "pw_basis_k.cpp", + PARAM.inp.calculation, + "Current core has no plane waves! 
Please reduce the cores."); + if (this->npwk_max < ng) { this->npwk_max = ng; } } - - //get igl2isz_k and igl2ig_k - if(this->npwk_max <= 0) { return; -} - delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max]; - delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max]; + // get igl2isz_k and igl2ig_k + if (this->npwk_max <= 0) + { + return; + } + + delete[] igl2isz_k; + this->igl2isz_k = new int[this->nks * this->npwk_max]; + delete[] igl2ig_k; + this->igl2ig_k = new int[this->nks * this->npwk_max]; for (int ik = 0; ik < this->nks; ik++) { int igl = 0; - for (int ig = 0; ig < this->npw ; ig++) + for (int ig = 0; ig < this->npw; ig++) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); + const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { - this->igl2isz_k[ik*npwk_max + igl] = this->ig2isz[ig]; - this->igl2ig_k[ik*npwk_max + igl] = ig; + this->igl2isz_k[ik * npwk_max + igl] = this->ig2isz[ig]; + this->igl2ig_k[ik * npwk_max + igl] = ig; ++igl; } } } #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { + if (this->device == "gpu") + { resmem_int_op()(this->d_igl2isz_k, this->npwk_max * this->nks); syncmem_int_h2d_op()(this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks); } @@ -172,7 +199,7 @@ void PW_Basis_K::setupIndGk() return; } -/// +/// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall /// set up ffts @@ -185,11 +212,36 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->fft_bundle.clear(); - this->fft_bundle.setfft(this->device,this->precision); - if(this->xprime){ - this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - }else{ - 
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); +#if defined(__DSP) + this->fft_bundle.setfft("dsp", this->precision); +#else + this->fft_bundle.setfft(this->device, this->precision); +#endif + if (this->xprime) + { + this->fft_bundle.initfft(this->nx, + this->ny, + this->nz, + this->lix, + this->rix, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); + } + else + { + this->fft_bundle.initfft(this->nx, + this->ny, + this->nz, + this->liy, + this->riy, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); } this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); @@ -200,8 +252,10 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h this->erf_ecut = erf_ecut_in; this->erf_height = erf_height_in; this->erf_sigma = erf_sigma_in; - if(this->npwk_max <= 0) { return; -} + if (this->npwk_max <= 0) + { + return; + } delete[] gk2; delete[] gcar; this->gk2 = new double[this->npwk_max * this->nks]; @@ -210,10 +264,10 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); ModuleBase::Vector3 f; - for(int ik = 0 ; ik < this->nks ; ++ik) + for (int ik = 0; ik < this->nks; ++ik) { ModuleBase::Vector3 kv = this->kvec_d[ik]; - for(int igl = 0 ; igl < this-> npwk[ik] ; ++igl) + for (int igl = 0; igl < this->npwk[ik]; ++igl) { int isz = this->igl2isz_k[ik * npwk_max + igl]; int iz = isz % this->nz; @@ -221,12 +275,18 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h int ixy = this->is2fftixy[is]; int ix = ixy / this->fftny; int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) { ix -= this->nx; -} - if (iy >= int(this->ny/2) + 1) { iy -= this->ny; -} - if (iz >= int(this->nz/2) + 1) { iz -= 
this->nz; -} + if (ix >= int(this->nx / 2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny / 2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz / 2) + 1) + { + iz -= this->nz; + } f.x = ix; f.y = iy; f.z = iz; @@ -245,30 +305,42 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h } } #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { - if (this->precision == "single") { + if (this->device == "gpu") + { + if (this->precision == "single") + { resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); resmem_sd_op()(this->s_gcar, this->npwk_max * this->nks * 3); castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); - castmem_d2s_h2d_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); + castmem_d2s_h2d_op()(this->s_gcar, + reinterpret_cast(&this->gcar[0][0]), + this->npwk_max * this->nks * 3); } - else { + else + { resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); resmem_dd_op()(this->d_gcar, this->npwk_max * this->nks * 3); syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); - syncmem_d2d_h2d_op()(this->d_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); + syncmem_d2d_h2d_op()(this->d_gcar, + reinterpret_cast(&this->gcar[0][0]), + this->npwk_max * this->nks * 3); } } - else { + else + { #endif - if (this->precision == "single") { + if (this->precision == "single") + { resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); resmem_sh_op()(this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar"); castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); - castmem_d2s_h2h_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); + castmem_d2s_h2h_op()(this->s_gcar, + reinterpret_cast(&this->gcar[0][0]), + this->npwk_max * this->nks * 3); } - else { - this->d_gcar = reinterpret_cast(&this->gcar[0][0]); + else + { + this->d_gcar 
= reinterpret_cast(&this->gcar[0][0]); this->d_gk2 = this->gk2; } // There's no need to allocate double pointers while in a CPU environment. @@ -277,18 +349,25 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif } -ModuleBase::Vector3 PW_Basis_K:: cal_GplusK_cartesian(const int ik, const int ig) const { +ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const +{ int isz = this->ig2isz[ig]; int iz = isz % this->nz; int is = isz / this->nz; int ix = this->is2fftixy[is] / this->fftny; int iy = this->is2fftixy[is] % this->fftny; - if (ix >= int(this->nx/2) + 1) { ix -= this->nx; -} - if (iy >= int(this->ny/2) + 1) { iy -= this->ny; -} - if (iz >= int(this->nz/2) + 1) { iz -= this->nz; -} + if (ix >= int(this->nx / 2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny / 2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz / 2) + 1) + { + iz -= this->nz; + } ModuleBase::Vector3 f; f.x = ix; f.y = iy; @@ -317,34 +396,34 @@ ModuleBase::Vector3 PW_Basis_K::getgdirect(const int ik, const int igl) return f; } - ModuleBase::Vector3 PW_Basis_K::getgpluskcar(const int ik, const int igl) const { - return this->gcar[ik * this->npwk_max + igl]+this->kvec_c[ik]; + return this->gcar[ik * this->npwk_max + igl] + this->kvec_c[ik]; } int& PW_Basis_K::getigl2isz(const int ik, const int igl) const { - return this->igl2isz_k[ik*this->npwk_max + igl]; + return this->igl2isz_k[ik * this->npwk_max + igl]; } int& PW_Basis_K::getigl2ig(const int ik, const int igl) const { - return this->igl2ig_k[ik*this->npwk_max + igl]; + return this->igl2ig_k[ik * this->npwk_max + igl]; } - void PW_Basis_K::get_ig2ixyz_k() { +#if not defined(__DSP) if (this->device != "gpu") { - //only GPU need to get ig2ixyz_k + // only GPU need to get ig2ixyz_k return; } - int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks]; +#endif + ig2ixyz_k_cpu.resize(this->npwk_max * this->nks); ModuleBase::Memory::record("PW_B_K::ig2ixyz", 
sizeof(int) * this->npwk_max * this->nks); - assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily. - for(int ik = 0; ik < this->nks; ++ik) + assert(gamma_only == false); // We only finish non-gamma_only fft on GPU temperarily. + for (int ik = 0; ik < this->nks; ++ik) { - for(int igl = 0; igl < this->npwk[ik]; ++igl) + for (int igl = 0; igl < this->npwk[ik]; ++igl) { int isz = this->igl2isz_k[igl + ik * npwk_max]; int iz = isz % this->nz; @@ -356,8 +435,7 @@ void PW_Basis_K::get_ig2ixyz_k() } } resmem_int_op()(ig2ixyz_k, this->npwk_max * this->nks); - syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks); - delete[] ig2ixyz_k_cpu; + syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu.data(), this->npwk_max * this->nks); } std::vector PW_Basis_K::get_ig2ix(const int ik) const @@ -365,14 +443,16 @@ std::vector PW_Basis_K::get_ig2ix(const int ik) const std::vector ig_to_ix; ig_to_ix.resize(npwk[ik]); - for(int ig = 0; ig < npwk[ik]; ig++) + for (int ig = 0; ig < npwk[ik]; ig++) { int isz = this->igl2isz_k[ig + ik * npwk_max]; int is = isz / this->nz; int ixy = this->is2fftixy[is]; int ix = ixy / this->ny; - if (ix < (nx / 2) + 1) { ix += nx; -} + if (ix < (nx / 2) + 1) + { + ix += nx; + } ig_to_ix[ig] = ix; } return ig_to_ix; @@ -383,14 +463,16 @@ std::vector PW_Basis_K::get_ig2iy(const int ik) const std::vector ig_to_iy; ig_to_iy.resize(npwk[ik]); - for(int ig = 0; ig < npwk[ik]; ig++) + for (int ig = 0; ig < npwk[ik]; ig++) { int isz = this->igl2isz_k[ig + ik * npwk_max]; int is = isz / this->nz; int ixy = this->is2fftixy[is]; int iy = ixy % this->ny; - if (iy < (ny / 2) + 1) { iy += ny; -} + if (iy < (ny / 2) + 1) + { + iy += ny; + } ig_to_iy[ig] = iy; } return ig_to_iy; @@ -401,42 +483,50 @@ std::vector PW_Basis_K::get_ig2iz(const int ik) const std::vector ig_to_iz; ig_to_iz.resize(npwk[ik]); - for(int ig = 0; ig < npwk[ik]; ig++) + for (int ig = 0; ig < npwk[ik]; ig++) { int isz = this->igl2isz_k[ig + ik * 
npwk_max]; int iz = isz % this->nz; - if (iz < (nz / 2) + 1) { iz += nz; -} + if (iz < (nz / 2) + 1) + { + iz += nz; + } ig_to_iz[ig] = iz; } return ig_to_iz; } template <> -float * PW_Basis_K::get_kvec_c_data() const { +float* PW_Basis_K::get_kvec_c_data() const +{ return this->s_kvec_c; } template <> -double * PW_Basis_K::get_kvec_c_data() const { +double* PW_Basis_K::get_kvec_c_data() const +{ return this->d_kvec_c; } template <> -float * PW_Basis_K::get_gcar_data() const { +float* PW_Basis_K::get_gcar_data() const +{ return this->s_gcar; } template <> -double * PW_Basis_K::get_gcar_data() const { +double* PW_Basis_K::get_gcar_data() const +{ return this->d_gcar; } template <> -float * PW_Basis_K::get_gk2_data() const { +float* PW_Basis_K::get_gk2_data() const +{ return this->s_gk2; } template <> -double * PW_Basis_K::get_gk2_data() const { +double* PW_Basis_K::get_gk2_data() const +{ return this->d_gk2; } -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index f5be09cfbd..ae5076bba9 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -87,7 +87,7 @@ class PW_Basis_K : public PW_Basis int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz) int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz - + std::vector ig2ixyz_k_cpu; /// [npw] map ig to ixyz,which is used in dsp fft. 
double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks] // liuyu add 2023-09-06 @@ -135,6 +135,31 @@ class PW_Basis_K : public PW_Basis const int ik, const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + #if defined(__DSP) + template + void convolution(const Device* ctx, + const int ik, + const int size, + const std::complex* input, + const FPTYPE* input1, + std::complex* output, + const bool add = false, + const FPTYPE factor =1.0) const ; + + template + void real2recip_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template + void recip2real_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + + #endif template void real_to_recip(const Device* ctx, diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index e230066c8f..3d75f07f6f 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -307,7 +307,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, const bool add, const double factor) const { - this->real2recip(in, out, ik, add, factor); + #if defined(__DSP) + this->real2recip_dsp(in,out,ik,add,factor); + #else + this->real2recip(in, out, ik, add, factor); + #endif } template <> @@ -328,7 +332,11 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, const bool add, const double factor) const { - this->recip2real(in, out, ik, add, factor); + #if defined(__DSP) + this->recip2real_dsp(in,out,ik,add,factor); + #else + this->recip2real(in, out, ik, add, factor); + #endif } #if (defined(__CUDA) || defined(__ROCM)) diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp 
b/source/module_basis/module_pw/pw_transform_k_dsp.cpp new file mode 100644 index 0000000000..b292e25f0a --- /dev/null +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -0,0 +1,156 @@ +#include "module_base/timer.h" +#include "module_basis/module_pw/kernels/pw_op.h" +#include "pw_basis_k.h" +#include "pw_gatherscatter.h" + +#include +#include +#if defined (__DSP) +namespace ModulePW +{ +template +void PW_Basis_K::real2recip_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const FPTYPE factor) const +{ + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + assert(this->gamma_only == false); + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + // copy the in into the auxr with complex + memcpy(auxr, in, this->nrxx * 2 * 8); + + // 3d fft + this->fft_bundle.resource_handler(1); + this->fft_bundle.fft3D_forward(gpux, + auxr, + auxr); + this->fft_bundle.resource_handler(0); + // copy the result from the auxr to the out ,while consider the add + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu.data() + startig, + auxr, + out); +} +template +void PW_Basis_K::recip2real_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const FPTYPE factor) const +{ + assert(this->gamma_only == false); + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr, 0, this->nxyz * 2 * 8); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + // copy the mapping form the type of stick to the 3dfft + set_3d_fft_box_op()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr); + // use 3d fft backward + this->fft_bundle.resource_handler(1); + this->fft_bundle.fft3D_backward(gpux, 
auxr, auxr); + this->fft_bundle.resource_handler(0); + if (add) + { + const int one = 1; + const std::complex factor1 = std::complex(factor, 0); + zaxpy_(&nrxx, &factor1, auxr, &one, out, &one); + } + else + { + memcpy(out, auxr, nrxx * 2 * 8); + } +} +template <> +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const float* input1, + std::complex* output, + const bool add, + const float factor) const +{ +} + +template <> +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const double* input1, + std::complex* output, + const bool add, + const double factor) const +{ + ModuleBase::timer::tick(this->classname, "convolution"); + + assert(this->gamma_only == false); + const base_device::DEVICE_GPU* gpux; + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr, 0, this->nxyz * 2 * 8); + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + + // copy the mapping form the type of stick to the 3dfft + set_3d_fft_box_op()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr); + + // use 3d fft backward + this->fft_bundle.fft3D_backward(gpux, auxr, auxr); + + for (int ir = 0; ir < size; ir++) + { + auxr[ir] *= input1[ir]; + } + + // 3d fft + this->fft_bundle.fft3D_forward(gpux, auxr, auxr); + // copy the result from the auxr to the out ,while consider the add + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu.data() + startig, + auxr, + output); + ModuleBase::timer::tick(this->classname, "convolution"); +} + +// template void PW_Basis_K::real2recip_dsp(const std::complex* in, +// std::complex* out, +// const int ik, +// const bool add, +// const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns) +// template void PW_Basis_K::recip2real_dsp(const 
std::complex* in, +// std::complex* out, +// const int ik, +// const bool add, +// const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny) + +template void PW_Basis_K::real2recip_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns) +template void PW_Basis_K::recip2real_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const double factor) const; +} // namespace ModulePW +#endif diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 0961675029..74d3904a65 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -81,7 +81,7 @@ ESolver_KS_PW::ESolver_KS_PW() #endif #ifdef __DSP std::cout << " ** Initializing DSP Hardware..." << std::endl; - dspInitHandle(GlobalV::MY_RANK); + mtfunc::dspInitHandle(GlobalV::MY_RANK); #endif } @@ -109,7 +109,7 @@ ESolver_KS_PW::~ESolver_KS_PW() } #ifdef __DSP std::cout << " ** Closing DSP Hardware..." 
<< std::endl; - dspDestoryHandle(GlobalV::MY_RANK); + mtfunc::dspDestoryHandle(GlobalV::MY_RANK); #endif if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp index 6bff6b2dc0..54e1a052be 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp @@ -53,7 +53,9 @@ void Veff>::act( int max_npw = nbasis / npol; const int current_spin = this->isk[this->ik]; - +#ifdef __DSP + wfcpw->fft_bundle.resource_handler(1); +#endif // T *porter = new T[wfcpw->nmaxgr]; for (int ib = 0; ib < nbands; ib += npol) { @@ -75,6 +77,13 @@ void Veff>::act( } // wfcpw->real2recip(porter, tmhpsi, this->ik, true); wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true); + // wfcpw->convolution(this->ctx, + // this->ik, + // this->veff_col, + // tmpsi_in, + // this->veff+current_spin, + // tmhpsi, + // true); } else { @@ -111,6 +120,9 @@ void Veff>::act( tmhpsi += max_npw * npol; tmpsi_in += max_npw * npol; } +#ifdef __DSP + wfcpw->fft_bundle.resource_handler(0); +#endif ModuleBase::timer::tick("Operator", "VeffPW"); } diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 177e68847c..2d1b747de4 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -454,7 +454,7 @@ void Diago_DavSubspace::cal_elem(const int& dim, { #ifdef __DSP // Only on dsp hardware need an extra space to reduce data - dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); + mtfunc::dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); #else auto* swap = new T[notconv * this->nbase_x];