Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a)
endif()

target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR})

if (USE_SW)
add_compile_definitions(__SW)
set(SW ON)
Expand All @@ -295,6 +298,7 @@ if (USE_SW)
target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswblas.a)
endif()


find_package(Threads REQUIRED)
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)

Expand Down
44 changes: 25 additions & 19 deletions source/source_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,33 @@ extern "C"
}
namespace mtfunc
{
std::complex<double>* gemm_alpha_double=nullptr;
std::complex<double>* gemm_beta_double=nullptr;
std::complex<float>* gemm_alpha_float=nullptr;
std::complex<float>* gemm_beta_float=nullptr;

/**
 * @brief Initialise a DSP cluster and set up the shared GEMM scalar buffers.
 *
 * Calls mt_blas_init() for the cluster, prints a notice to stdout, and then
 * requests four single-element buffers through malloc_ht(): alpha and beta
 * for std::complex<double>, and alpha and beta for std::complex<float>.
 * The resulting pointers are stored in the namespace-level gemm_* variables.
 *
 * NOTE(review): the pointers returned by malloc_ht() are stored unchecked,
 * and calling this function a second time overwrites the previous pointers
 * without releasing them -- confirm callers invoke it once per run.
 *
 * @param id cluster identifier forwarded to mt_blas_init() and malloc_ht()
 */
void dspInitHandle(int id)
{
    using ZScalar = std::complex<double>;
    using CScalar = std::complex<float>;

    mt_blas_init(id);
    std::cout << " ** DSP inited on cluster " << id << " **" << std::endl;

    // Double-precision complex alpha / beta.
    mtfunc::gemm_alpha_double = static_cast<ZScalar*>(mtfunc::malloc_ht(sizeof(ZScalar), id));
    mtfunc::gemm_beta_double = static_cast<ZScalar*>(mtfunc::malloc_ht(sizeof(ZScalar), id));

    // Single-precision complex alpha / beta.
    mtfunc::gemm_alpha_float = static_cast<CScalar*>(mtfunc::malloc_ht(sizeof(CScalar), id));
    mtfunc::gemm_beta_float = static_cast<CScalar*>(mtfunc::malloc_ht(sizeof(CScalar), id));
} // Use this at the beginning of the program to start a dsp cluster

/**
 * @brief Release the shared GEMM scalar buffers and close a DSP cluster.
 *
 * The four gemm_* buffers were obtained through malloc_ht() (which wraps
 * hthread_malloc on a cluster), so they are released *before*
 * hthread_dev_close() is called; previously they were freed only after the
 * cluster had already been closed. Each pointer is reset to nullptr after
 * being freed, and null pointers are skipped so that calling this function
 * without a prior dspInitHandle() does not hand nullptr to hthread_free().
 *
 * @param id cluster identifier forwarded to hthread_dev_close()
 */
void dspDestoryHandle(int id)
{
    // Free the scalar buffers while the cluster is still open.
    if (mtfunc::gemm_alpha_double != nullptr)
    {
        mtfunc::free_ht(mtfunc::gemm_alpha_double);
        mtfunc::gemm_alpha_double = nullptr;
    }
    if (mtfunc::gemm_beta_double != nullptr)
    {
        mtfunc::free_ht(mtfunc::gemm_beta_double);
        mtfunc::gemm_beta_double = nullptr;
    }
    if (mtfunc::gemm_alpha_float != nullptr)
    {
        mtfunc::free_ht(mtfunc::gemm_alpha_float);
        mtfunc::gemm_alpha_float = nullptr;
    }
    if (mtfunc::gemm_beta_float != nullptr)
    {
        mtfunc::free_ht(mtfunc::gemm_beta_float);
        mtfunc::gemm_beta_float = nullptr;
    }

    // Only now shut the cluster down and report it.
    hthread_dev_close(id);
    std::cout << " ** DSP closed on cluster " << id << " **" << std::endl;
} // Close dsp cluster at the end

MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
Expand All @@ -45,19 +62,15 @@ MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)

/**
 * @brief Allocate a buffer through hthread_malloc on the given cluster.
 *
 * The request is made with the HT_MEM_RW flag. The pointer returned by
 * hthread_malloc is passed back unchanged and is not checked here.
 *
 * @param bytes      number of bytes to request
 * @param cluster_id cluster identifier passed to hthread_malloc
 * @return whatever hthread_malloc returned for this request
 */
void* malloc_ht(size_t bytes, int cluster_id)
{
    return hthread_malloc(cluster_id, bytes, HT_MEM_RW);
}

// Used to replace original malloc

/**
 * @brief Release a buffer by handing it to hthread_free.
 *
 * Counterpart of malloc_ht(); the pointer is forwarded as-is with no
 * null check.
 *
 * @param ptr pointer to release (presumably one obtained from malloc_ht --
 *            confirm against callers)
 */
void free_ht(void* ptr)
{
    hthread_free(ptr);
}

// Used to replace original free
Expand Down Expand Up @@ -271,22 +284,20 @@ void zgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;
*gemm_alpha_double = *alpha;
*gemm_beta_double = *beta;
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
*n,
*k,
alp,
gemm_alpha_double,
a,
*lda,
b,
*ldb,
bet,
gemm_beta_double,
c,
*ldc,
cluster_id);
Expand All @@ -308,28 +319,23 @@ void cgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
std::complex<float>* alp = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*alp = *alpha;
std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*bet = *beta;
gemm_alpha_float = alpha;
gemm_beta_float = beta;

mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
*n,
*k,
(const void*)alp,
(const void*)gemm_alpha_float,
(const void*)a,
*lda,
(const void*)b,
*ldb,
(const void*)bet,
(const void*)gemm_beta_float,
(void*)c,
*ldc,
cluster_id);

free_ht(alp);
free_ht(bet);
} // cgemm that needn't malloc_ht or free_ht
} // namespace mtfunc
5 changes: 4 additions & 1 deletion source/source_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ void* malloc_ht(size_t bytes, int cluster_id);
void free_ht(void* ptr);

// mtblas functions

extern std::complex<double>* gemm_alpha_double;
extern std::complex<double>* gemm_beta_double;
extern std::complex<float>* gemm_alpha_float;
extern std::complex<float>* gemm_beta_float;
void sgemm_mt_(const char* transa,
const char* transb,
const int* m,
Expand Down
4 changes: 2 additions & 2 deletions source/source_basis/module_pw/kernels/pw_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct set_3d_fft_box_op {

template <typename FPTYPE, typename Device>
struct set_recip_to_real_output_op {
/// @brief Calculate the outputs after the FFT translation of recip_to_real
/// @brief Calculate the outputs after the FFT translation of recip2real_impl
///
/// Input Parameters
/// @param dev - which device this function runs on
Expand All @@ -54,7 +54,7 @@ struct set_recip_to_real_output_op {

template <typename FPTYPE, typename Device>
struct set_real_to_recip_output_op {
/// @brief Calculate the outputs after the FFT translation of real_to_recip
/// @brief Calculate the outputs after the FFT translation of real2recip_impl
///
/// Input Parameters
/// @param dev - which device this function runs on
Expand Down
2 changes: 1 addition & 1 deletion source/source_basis/module_pw/module_fft/fft_dsp.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "fft_dsp.h"

#include "source_base/global_variable.h"

#include "source_base/tool_quit.h"
#include <iostream>
#include <string.h>
#include <vector>
Expand Down
10 changes: 5 additions & 5 deletions source/source_basis/module_pw/pw_basis.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class PW_Basis
&& std::is_same<Device, base_device::DEVICE_CPU>::value,
int>::type
= 0>
void recip_to_real(TK* in, TR* out, const bool add = false, const typename GetTypeReal<TK>::type factor = 1.0) const
void recip2real_impl(TK* in, TR* out, const bool add = false, const typename GetTypeReal<TK>::type factor = 1.0) const
{
this->recip2real(in, out, add, factor);
};
Expand Down Expand Up @@ -350,7 +350,7 @@ class PW_Basis
&& std::is_same<Device, base_device::DEVICE_GPU>::value,
int>::type
= 0>
void recip_to_real(TK* in,
void recip2real_impl(TK* in,
TR* out,
const bool add = false,
const typename GetTypeReal<TK>::type factor = 1.0) const
Expand All @@ -361,7 +361,7 @@ class PW_Basis
// template <typename FPTYPE,
// typename Device,
// typename std::enable_if<!std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type = 0>
// void recip_to_real(FPTYPE* in,
// void recip2real_impl(FPTYPE* in,
// FPTYPE* out,
// const bool add = false,
// const typename GetTypeReal<FPTYPE>::type factor = 1.0) const;
Expand Down Expand Up @@ -391,7 +391,7 @@ class PW_Basis
typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
&& (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
&& std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
void real_to_recip(TR* in,
void real2recip_impl(TR* in,
TK* out,
const bool add = false,
const typename GetTypeReal<TK>::type factor = 1.0) const
Expand All @@ -403,7 +403,7 @@ class PW_Basis
typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
&& (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
&& std::is_same<Device, base_device::DEVICE_GPU>::value ,int>::type = 0>
void real_to_recip(TR* in,
void real2recip_impl(TR* in,
TK* out,
const bool add = false,
const typename GetTypeReal<TK>::type factor = 1.0) const
Expand Down
12 changes: 6 additions & 6 deletions source/source_basis/module_pw/pw_basis_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ class PW_Basis_K : public PW_Basis
#endif

template <typename FPTYPE, typename Device>
void real_to_recip(const Device* ctx,
void real2recip_impl(const Device* ctx,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out,
const int ik,
const bool add = false,
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
template <typename FPTYPE, typename Device>
void recip_to_real(const Device* ctx,
void recip2real_impl(const Device* ctx,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out,
const int ik,
Expand All @@ -180,7 +180,7 @@ class PW_Basis_K : public PW_Basis
template <typename TK,
typename Device,
typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
void real_to_recip(const TK* in,
void real2recip_impl(const TK* in,
TK* out,
const int ik,
const bool add = false,
Expand All @@ -195,7 +195,7 @@ class PW_Basis_K : public PW_Basis
template <typename TK,
typename Device,
typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
void recip_to_real(const TK* in,
void recip2real_impl(const TK* in,
TK* out,
const int ik,
const bool add = false,
Expand Down Expand Up @@ -225,7 +225,7 @@ class PW_Basis_K : public PW_Basis
template <typename FPTYPE,
typename Device,
typename std::enable_if<!std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
void real_to_recip(const FPTYPE* in,
void real2recip_impl(const FPTYPE* in,
FPTYPE* out,
const int ik,
const bool add = false,
Expand All @@ -237,7 +237,7 @@ class PW_Basis_K : public PW_Basis
template <typename TK,
typename Device,
typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0>
void recip_to_real(const TK* in,
void recip2real_impl(const TK* in,
TK* out,
const int ik,
const bool add = false,
Expand Down
16 changes: 8 additions & 8 deletions source/source_basis/module_pw/pw_transform_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ModulePW
template <typename FPTYPE>
void PW_Basis::real2recip_gpu(const FPTYPE* in, std::complex<FPTYPE>* out, const bool add, const FPTYPE factor) const
{
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
ModuleBase::timer::tick(this->classname, "real2recip_impl gpu");
assert(this->poolnproc == 1);
const size_t size = this->nrxx;
base_device::memory::cast_memory_op<std::complex<FPTYPE>, FPTYPE,base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(
Expand All @@ -25,15 +25,15 @@ void PW_Basis::real2recip_gpu(const FPTYPE* in, std::complex<FPTYPE>* out, const
this->ig2ixyz_gpu,
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
out);
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
ModuleBase::timer::tick(this->classname, "real2recip_impl gpu");
}
template <typename FPTYPE>
void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out,
const bool add,
const FPTYPE factor) const
{
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
ModuleBase::timer::tick(this->classname, "real2recip_impl gpu");
assert(this->poolnproc == 1);
base_device::memory::synchronize_memory_op<std::complex<FPTYPE>,
base_device::DEVICE_GPU,
Expand All @@ -50,13 +50,13 @@ void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
this->ig2ixyz_gpu,
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
out);
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
ModuleBase::timer::tick(this->classname, "real2recip_impl gpu");
}

template <typename FPTYPE>
void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in, FPTYPE* out, const bool add, const FPTYPE factor) const
{
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
ModuleBase::timer::tick(this->classname, "recip2real_impl gpu");
assert(this->poolnproc == 1);
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<FPTYPE>(), this->nxyz);
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
Expand All @@ -76,15 +76,15 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in, FPTYPE* out, const
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
out);

ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
ModuleBase::timer::tick(this->classname, "recip2real_impl gpu");
}
template <typename FPTYPE>
void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out,
const bool add,
const FPTYPE factor) const
{
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
ModuleBase::timer::tick(this->classname, "recip2real_impl gpu");
assert(this->poolnproc == 1);
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz);
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
Expand All @@ -105,7 +105,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
out);

ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
ModuleBase::timer::tick(this->classname, "recip2real_impl gpu");
}
template void PW_Basis::real2recip_gpu<double>(const double* in,
std::complex<double>* out,
Expand Down
Loading
Loading