Skip to content

Commit c1dac21

Browse files
committed
set fft_dsp
1 parent 3f8fe4f commit c1dac21

File tree

6 files changed

+237
-0
lines changed

6 files changed

+237
-0
lines changed

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ if (USE_ROCM)
1313
module_fft/fft_rocm.cpp
1414
)
1515
endif()
16+
if (USE_DSP)
17+
list (APPEND FFT_SRC
18+
module_fft/fft_dsp.cpp
19+
pw_transform_k_dsp.cpp)
20+
endif()
1621

1722
list(APPEND objects
1823
pw_basis.cpp

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ void FFT_Bundle::initfft(int nx_in,
6565

6666
if (device=="cpu")
6767
{
68+
#if defined(__DSP)
69+
if (float_flag==true)
70+
ModuleBase::WARNING_QUT("device","now dsp is not support for the float type");
71+
fft_double=make_unique<FFT_DSP<double>>();
72+
fft_double->initfft(nx_in,ny_in,nz_in);
73+
#endif
6874
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
6975
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
7076
if (float_flag)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// #define protected public
2+
// #define private public
3+
#include "fft_dsp.h"
4+
#include <string.h>
5+
#include <iostream>
6+
#include <vector>
7+
// #undef private
8+
// #undef protected
9+
namespace ModulePW
10+
{
11+
template<>
12+
void FFT_DSP<double>::initfft(int nx_in,int ny_in,int nz_in)
13+
{
14+
this->nx=nx_in;
15+
this->ny=ny_in;
16+
this->nz=nz_in;
17+
cluster_id = 1;
18+
}
19+
template<>
20+
void FFT_DSP<double>::setupFFT()
21+
{
22+
PROBLEM pbm_forward;
23+
PROBLEM pbm_backward;
24+
PLAN* ptr_plan_forward;
25+
PLAN* ptr_plan_backward;
26+
INT num_thread=8;
27+
28+
INT size;
29+
//open cluster id
30+
hthread_dev_open(cluster_id);
31+
//load mt.dat
32+
hthread_dat_load(cluster_id, "mt_fft_device.dat");
33+
34+
thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL);
35+
//create b_id for the barrier
36+
b_id = hthread_barrier_create(cluster_id);
37+
args_for[0] = b_id;
38+
39+
//compute the size of and malloc thread
40+
size = nx*ny*nz*2*sizeof(E);
41+
forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW);
42+
43+
// //init 3d fft problem
44+
pbm_forward.num_dim = 3;
45+
pbm_forward.n[0] = nx;
46+
pbm_forward.n[1] = ny;
47+
pbm_forward.n[2] = nz;
48+
pbm_forward.iFFT = 0;
49+
pbm_forward.in = forward_in;
50+
pbm_forward.out = forward_in;
51+
52+
// //make ptr plan
53+
make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread);
54+
ptr_plan_forward->in = forward_in;
55+
ptr_plan_forward->out = forward_in;
56+
args_for[1] = (unsigned long)ptr_plan_forward;
57+
58+
//init 3d fft problem
59+
pbm_backward.num_dim = 3; // dimensions of FFT
60+
pbm_backward.n[0] = nx; // first dimension
61+
pbm_backward.n[1] = ny; // second dimension
62+
pbm_backward.n[2] = nz; // third dimension
63+
pbm_backward.iFFT = 1; // 0 stand for forward,1 stand for backward
64+
pbm_backward.in = forward_in; // the input data
65+
pbm_backward.out = forward_in; // the output data
66+
67+
make_plan(&pbm_backward, &ptr_plan_backward, cluster_id, num_thread);
68+
ptr_plan_backward->in = forward_in;
69+
ptr_plan_backward->out = forward_in;
70+
args_back[0]=b_id;
71+
args_back[1]=(unsigned long)ptr_plan_backward;
72+
}
73+
74+
template<>
75+
void FFT_DSP<double>::fft3D_forward(std::complex<double>* in,
76+
std::complex<double>* out) const
77+
{
78+
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for);
79+
hthread_group_wait(thread_id_for);
80+
}
81+
82+
template<>
83+
void FFT_DSP<double>::fft3D_backward(std::complex<double> * in,
84+
std::complex<double>* out) const
85+
{
86+
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back);
87+
hthread_group_wait(thread_id_for);
88+
89+
}
90+
template<>
91+
void FFT_DSP<double>::cleanFFT()
92+
{
93+
if (ptr_plan_forward!=nullptr)
94+
{
95+
destroy_plan(ptr_plan_forward);
96+
ptr_plan_forward=nullptr;
97+
}
98+
if (ptr_plan_backward!=nullptr)
99+
{
100+
destroy_plan(ptr_plan_backward);
101+
ptr_plan_backward=nullptr;
102+
}
103+
}
104+
105+
template<>
106+
void FFT_DSP<double>::clear()
107+
{
108+
this->cleanFFT();
109+
hthread_free(forward_in);
110+
hthread_barrier_destroy(b_id);
111+
hthread_group_destroy(thread_id_for);
112+
}
113+
114+
template<> std::complex<double>*
115+
FFT_DSP<double>::get_auxr_3d_data() const
116+
{
117+
return reinterpret_cast<std::complex<double>*>(this->forward_in);
118+
}
119+
template FFT_DSP<float>::FFT_DSP();
120+
template FFT_DSP<float>::~FFT_DSP();
121+
template FFT_DSP<double>::FFT_DSP();
122+
template FFT_DSP<double>::~FFT_DSP();
123+
124+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef FFT_CUDA_H
2+
#define FFT_CUDA_H
3+
4+
#include "fft_base.h"
5+
#include <time.h>
6+
#include <stdlib.h>
7+
#include <math.h>
8+
9+
#include "hthread_host.h"
10+
#include "mtfft.h"
11+
#include "fftw3.h"
12+
13+
namespace ModulePW
14+
{
15+
template <typename FPTYPE>
16+
class FFT_DSP : public FFT_BASE<FPTYPE>
17+
{
18+
public:
19+
FFT_DSP(){};
20+
~FFT_DSP(){};
21+
22+
void setupFFT() override;
23+
24+
void clear() override;
25+
26+
void cleanFFT() override;
27+
28+
/**
29+
* @brief Initialize the fft parameters
30+
* @param nx_in number of grid points in x direction
31+
* @param ny_in number of grid points in y direction
32+
* @param nz_in number of grid points in z direction
33+
*
34+
*/
35+
virtual __attribute__((weak))
36+
void initfft(int nx_in,
37+
int ny_in,
38+
int nz_in) override;
39+
40+
/**
41+
* @brief Get the real space data
42+
* @return real space data
43+
*/
44+
virtual __attribute__((weak))
45+
std::complex<FPTYPE>* get_auxr_3d_data() const override;
46+
47+
/**
48+
* @brief Forward FFT in 3D
49+
* @param in input data, complex FPTYPE
50+
* @param out output data, complex FPTYPE
51+
*
52+
* This function performs the forward FFT in 3D.
53+
*/
54+
virtual __attribute__((weak))
55+
void fft3D_forward(std::complex<FPTYPE>* in,
56+
std::complex<FPTYPE>* out) const override;
57+
/**
58+
* @brief Backward FFT in 3D
59+
* @param in input data, complex FPTYPE
60+
* @param out output data, complex FPTYPE
61+
*
62+
* This function performs the backward FFT in 3D.
63+
*/
64+
virtual __attribute__((weak))
65+
void fft3D_backward(std::complex<FPTYPE>* in,
66+
std::complex<FPTYPE>* out) const override;
67+
public:
68+
INT cluster_id=0;
69+
INT b_id;
70+
INT thread_id_for=0;
71+
PLAN* ptr_plan_forward=nullptr;
72+
PLAN* ptr_plan_backward=nullptr;
73+
mutable unsigned long args_for[2];
74+
mutable unsigned long args_back[2];
75+
mutable E * forward_in;
76+
std::complex<float>* c_auxr_3d = nullptr; // fft space
77+
std::complex<double>* z_auxr_3d = nullptr; // fft space
78+
79+
};
80+
void test_fft_dsp();
81+
} // namespace ModulePW
82+
#endif
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include "fft_dsp.h"
2+
namespace ModulePW
3+
{
4+
5+
template<>
6+
void FFT_DSP<float>::setupFFT()
7+
{
8+
9+
}
10+
template<>
11+
void FFT_DSP<float>::clear()
12+
{
13+
14+
}
15+
template<>
16+
void FFT_DSP<float>::cleanFFT()
17+
{
18+
19+
}
20+
}

source/module_basis/module_pw/pw_transform_k_dsp.cpp

Whitespace-only changes.

0 commit comments

Comments
 (0)