Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ endif()
if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY})
add_compile_definitions(__DSP)
target_include_directories(${ABACUS_BIN_NAME} ${DIR_MTFFT_INCLUDES})
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTFFT_LIBRARY})
target_include_directories(${ABACUS_BIN_NAME} ${DIR_HTRERAD_INLCUDES})
endif()

find_package(Threads REQUIRED)
Expand Down
5 changes: 5 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ if (USE_ROCM)
module_fft/fft_rocm.cpp
)
endif()
if (USE_DSP)
list (APPEND FFT_SRC
module_fft/fft_dsp.cpp
pw_transform_k_dsp.cpp)
endif()

list(APPEND objects
pw_basis.cpp
Expand Down
11 changes: 10 additions & 1 deletion source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#if defined(__ROCM)
#include "fft_rocm.h"
#endif

#if defined(__DSP)
#include "fft_dsp.h"
#endif
template<typename FFT_BASE, typename... Args>
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
{
Expand Down Expand Up @@ -65,6 +67,12 @@ void FFT_Bundle::initfft(int nx_in,

if (device=="cpu")
{
#if defined(__DSP)
if (float_flag==true)
ModuleBase::WARNING_QUT("device","now dsp fft is not support for the float type");
fft_double=make_unique<FFT_DSP<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#else
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
if (float_flag)
Expand Down Expand Up @@ -93,6 +101,7 @@ void FFT_Bundle::initfft(int nx_in,
gamma_only_in,
xprime_in);
}
#endif
}
if (device=="gpu")
{
Expand Down
120 changes: 120 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_dsp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "fft_dsp.h"
#include "module_base/global_variable.h"
#include <string.h>
#include <iostream>
#include <vector>
namespace ModulePW
{
template<>
void FFT_DSP<double>::initfft(int nx_in,int ny_in,int nz_in)
{
this->nx=nx_in;
this->ny=ny_in;
this->nz=nz_in;
cluster_id = GlobalV::MY_RANK;
}
template<>
void FFT_DSP<double>::setupFFT()
{
PROBLEM pbm_forward;
PROBLEM pbm_backward;
PLAN* ptr_plan_forward;
PLAN* ptr_plan_backward;
INT num_thread=8;
INT size;
//open cluster id
hthread_dev_open(cluster_id);
//load mt.dat
hthread_dat_load(cluster_id, "mt_fft_device.dat");

thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL);
//create b_id for the barrier
b_id = hthread_barrier_create(cluster_id);
args_for[0] = b_id;

//compute the size of and malloc thread
size = nx*ny*nz*2*sizeof(E);
forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW);

// //init 3d fft problem
pbm_forward.num_dim = 3;
pbm_forward.n[0] = nx;
pbm_forward.n[1] = ny;
pbm_forward.n[2] = nz;
pbm_forward.iFFT = 0;
pbm_forward.in = forward_in;
pbm_forward.out = forward_in;

// //make ptr plan
make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread);
ptr_plan_forward->in = forward_in;
ptr_plan_forward->out = forward_in;
args_for[1] = (unsigned long)ptr_plan_forward;

//init 3d fft problem
pbm_backward.num_dim = 3; // dimensions of FFT
pbm_backward.n[0] = nx; // first dimension
pbm_backward.n[1] = ny; // second dimension
pbm_backward.n[2] = nz; // third dimension
pbm_backward.iFFT = 1; // 0 stand for forward,1 stand for backward
pbm_backward.in = forward_in; // the input data
pbm_backward.out = forward_in; // the output data

make_plan(&pbm_backward, &ptr_plan_backward, cluster_id, num_thread);
ptr_plan_backward->in = forward_in;
ptr_plan_backward->out = forward_in;
args_back[0]=b_id;
args_back[1]=(unsigned long)ptr_plan_backward;
}

template<>
void FFT_DSP<double>::fft3D_forward(std::complex<double>* in,
std::complex<double>* out) const
{
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for);
hthread_group_wait(thread_id_for);
}

template<>
void FFT_DSP<double>::fft3D_backward(std::complex<double> * in,
std::complex<double>* out) const
{
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back);
hthread_group_wait(thread_id_for);

}
template<>
void FFT_DSP<double>::cleanFFT()
{
if (ptr_plan_forward!=nullptr)
{
destroy_plan(ptr_plan_forward);
ptr_plan_forward=nullptr;
}
if (ptr_plan_backward!=nullptr)
{
destroy_plan(ptr_plan_backward);
ptr_plan_backward=nullptr;
}
}

template<>
void FFT_DSP<double>::clear()
{
this->cleanFFT();
hthread_free(forward_in);
hthread_barrier_destroy(b_id);
hthread_group_destroy(thread_id_for);
}

template<> std::complex<double>*
FFT_DSP<double>::get_auxr_3d_data() const
{
return reinterpret_cast<std::complex<double>*>(this->forward_in);
}
template FFT_DSP<float>::FFT_DSP();
template FFT_DSP<float>::~FFT_DSP();
template FFT_DSP<double>::FFT_DSP();
template FFT_DSP<double>::~FFT_DSP();

}
82 changes: 82 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_dsp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef FFT_CUDA_H
#define FFT_CUDA_H

#include "fft_base.h"
#include <ctime>
#include <cstdlib>
#include <cmath>

#include "hthread_host.h"
#include "mtfft.h"
#include "fftw3.h"

namespace ModulePW
{
template <typename FPTYPE>
class FFT_DSP : public FFT_BASE<FPTYPE>
{
public:
FFT_DSP(){};
~FFT_DSP(){};

void setupFFT() override;

void clear() override;

void cleanFFT() override;

/**
* @brief Initialize the fft parameters
* @param nx_in number of grid points in x direction
* @param ny_in number of grid points in y direction
* @param nz_in number of grid points in z direction
*
*/
virtual __attribute__((weak))
void initfft(int nx_in,
int ny_in,
int nz_in) override;

/**
* @brief Get the real space data
* @return real space data
*/
virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_3d_data() const override;

/**
* @brief Forward FFT in 3D
* @param in input data, complex FPTYPE
* @param out output data, complex FPTYPE
*
* This function performs the forward FFT in 3D.
*/
virtual __attribute__((weak))
void fft3D_forward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const override;
/**
* @brief Backward FFT in 3D
* @param in input data, complex FPTYPE
* @param out output data, complex FPTYPE
*
* This function performs the backward FFT in 3D.
*/
virtual __attribute__((weak))
void fft3D_backward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const override;
public:
INT cluster_id=0;
INT b_id;
INT thread_id_for=0;
PLAN* ptr_plan_forward=nullptr;
PLAN* ptr_plan_backward=nullptr;
mutable unsigned long args_for[2];
mutable unsigned long args_back[2];
mutable E * forward_in;
std::complex<float>* c_auxr_3d = nullptr; // fft space
std::complex<double>* z_auxr_3d = nullptr; // fft space

};
void test_fft_dsp();
} // namespace ModulePW
#endif
20 changes: 20 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_dsp_float.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "fft_dsp.h"
namespace ModulePW
{

template<>
void FFT_DSP<float>::setupFFT()
{

}
template<>
void FFT_DSP<float>::clear()
{

}
template<>
void FFT_DSP<float>::cleanFFT()
{

}
}
7 changes: 6 additions & 1 deletion source/module_basis/module_pw/pw_basis_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ PW_Basis_K::~PW_Basis_K()
delete[] igl2isz_k;
delete[] igl2ig_k;
delete[] gk2;
#if defined(__DSP)
delete[] ig2ixyz_k_cpu;
#endif
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu") {
if (this->precision == "single") {
Expand Down Expand Up @@ -357,7 +360,9 @@ void PW_Basis_K::get_ig2ixyz_k()
}
resmem_int_op()(ig2ixyz_k, this->npwk_max * this->nks);
syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
delete[] ig2ixyz_k_cpu;
#if not defined (__DSP)
delete[] this->ig2ixyz_k_cpu;
#endif
}

std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const
Expand Down
16 changes: 15 additions & 1 deletion source/module_basis/module_pw/pw_basis_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class PW_Basis_K : public PW_Basis
int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz

int *ig2ixyz_k_cpu = nullptr; /// [npw] map ig to ixyz,which is used in dsp fft.
double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]

// liuyu add 2023-09-06
Expand Down Expand Up @@ -135,6 +135,20 @@ class PW_Basis_K : public PW_Basis
const int ik,
const bool add = false,
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
#if defined(__DSP)
template <typename FPTYPE>
void real2recip_3d(const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out,
const int ik,
const bool add = false,
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
template <typename FPTYPE>
void recip2real_3d(const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out,
const int ik,
const bool add = false,
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
#endif

template <typename FPTYPE, typename Device>
void real_to_recip(const Device* ctx,
Expand Down
8 changes: 8 additions & 0 deletions source/module_basis/module_pw/pw_transform_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/,
const bool add,
const double factor) const
{
#if defined(__DSP)
this->real2recip_3d(in,out,ik,add,factor);
#else
this->real2recip(in, out, ik, add, factor);
#endif
}

template <>
Expand All @@ -318,7 +322,11 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
const bool add,
const float factor) const
{
#if defined(__DSP)
this->recip2real_3d(in,out,add,factor);
#else
this->recip2real(in, out, ik, add, factor);
#endif
}
template <>
void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
Expand Down
Loading
Loading