Skip to content

Commit b896780

Browse files
committed
add the convolution and allocate or destroy the b_id
1 parent 8c18170 commit b896780

File tree

7 files changed

+124
-100
lines changed

7 files changed

+124
-100
lines changed

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 66 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -7,166 +7,150 @@ namespace ModulePW
77
template <typename FPTYPE>
88
class FFT_BASE
99
{
10-
public:
10+
public:
11+
FFT_BASE() {};
12+
virtual ~FFT_BASE() {};
1113

12-
FFT_BASE(){};
13-
virtual ~FFT_BASE(){};
14-
1514
/**
1615
* @brief Initialize the fft parameters As virtual function.
17-
*
16+
*
1817
* The function is used to initialize the fft parameters.
1918
*/
20-
virtual __attribute__((weak))
21-
void initfft(int nx_in,
22-
int ny_in,
23-
int nz_in,
24-
int lixy_in,
25-
int rixy_in,
26-
int ns_in,
27-
int nplane_in,
28-
int nproc_in,
29-
bool gamma_only_in,
30-
bool xprime_in = true);
31-
32-
virtual __attribute__((weak))
33-
void initfft(int nx_in,
34-
int ny_in,
35-
int nz_in);
19+
virtual __attribute__((weak)) void initfft(int nx_in,
20+
int ny_in,
21+
int nz_in,
22+
int lixy_in,
23+
int rixy_in,
24+
int ns_in,
25+
int nplane_in,
26+
int nproc_in,
27+
bool gamma_only_in,
28+
bool xprime_in = true);
29+
30+
virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in);
3631

3732
/**
3833
* @brief Setup the fft Plan and data As pure virtual function.
39-
*
34+
*
4035
* The function is set as pure virtual function.In order to
4136
* override the function in the derived class.In the derived
4237
* class, the function is used to setup the fft Plan and data.
4338
*/
44-
virtual void setupFFT()=0;
39+
virtual void setupFFT() = 0;
4540

4641
/**
4742
* @brief Clean the fft Plan As pure virtual function.
48-
*
43+
*
4944
* The function is set as pure virtual function.In order to
5045
* override the function in the derived class.In the derived
5146
* class, the function is used to clean the fft Plan.
5247
*/
53-
virtual void cleanFFT()=0;
54-
48+
virtual void cleanFFT() = 0;
49+
5550
/**
5651
* @brief Clear the fft data As pure virtual function.
57-
*
52+
*
5853
* The function is set as pure virtual function.In order to
5954
* override the function in the derived class.In the derived
6055
* class, the function is used to clear the fft data.
6156
*/
62-
virtual void clear()=0;
63-
57+
virtual void clear() = 0;
58+
59+
virtual void resource_handler(const int flag) const {};
6460
/**
6561
* @brief Get the real space data in cpu-like fft
66-
*
62+
*
6763
* The function is used to get the real space data.While the
6864
* FFT_BASE is an abstract class,the function will be override,
69-
* The attribute weak is used to avoid define the function.
65+
* The attribute weak is used to avoid define the function.
7066
*/
71-
virtual __attribute__((weak))
72-
FPTYPE* get_rspace_data() const;
67+
virtual __attribute__((weak)) FPTYPE* get_rspace_data() const;
7368

74-
virtual __attribute__((weak))
75-
std::complex<FPTYPE>* get_auxr_data() const;
69+
virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxr_data() const;
7670

77-
virtual __attribute__((weak))
78-
std::complex<FPTYPE>* get_auxg_data() const;
71+
virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxg_data() const;
7972

8073
/**
8174
* @brief Get the auxiliary real space data in 3D
82-
*
75+
*
8376
* The function is used to get the auxiliary real space data in 3D.
8477
* While the FFT_BASE is an abstract class,the function will be override,
8578
* The attribute weak is used to avoid define the function.
8679
*/
87-
virtual __attribute__((weak))
88-
std::complex<FPTYPE>* get_auxr_3d_data() const;
80+
virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxr_3d_data() const;
8981

90-
//forward fft in x-y direction
82+
// forward fft in x-y direction
9183

9284
/**
9385
* @brief Forward FFT in x-y direction
9486
* @param in input data
9587
* @param out output data
96-
*
88+
*
9789
* This function performs the forward FFT in the x-y direction.
9890
* It involves two axes, x and y. The FFT is applied multiple times
99-
* along the left and right boundaries in the primary direction(which is
100-
* determined by the xprime flag).Notably, the Y axis operates in
91+
* along the left and right boundaries in the primary direction(which is
92+
* determined by the xprime flag).Notably, the Y axis operates in
10193
* "many-many-FFT" mode.
10294
*/
103-
virtual __attribute__((weak))
104-
void fftxyfor(std::complex<FPTYPE>* in,
105-
std::complex<FPTYPE>* out) const;
95+
virtual __attribute__((weak)) void fftxyfor(std::complex<FPTYPE>* in,
96+
std::complex<FPTYPE>* out) const;
10697

107-
virtual __attribute__((weak))
108-
void fftxybac(std::complex<FPTYPE>* in,
109-
std::complex<FPTYPE>* out) const;
98+
virtual __attribute__((weak)) void fftxybac(std::complex<FPTYPE>* in,
99+
std::complex<FPTYPE>* out) const;
110100

111101
/**
112102
* @brief Forward FFT in z direction
113103
* @param in input data
114104
* @param out output data
115-
*
105+
*
116106
* This function performs the forward FFT in the z direction.
117107
* It involves only one axis, z. The FFT is applied only once.
118108
* Notably, the Z axis operates in many FFT with nz*ns.
119109
*/
120-
virtual __attribute__((weak))
121-
void fftzfor(std::complex<FPTYPE>* in,
122-
std::complex<FPTYPE>* out) const;
123-
124-
virtual __attribute__((weak))
125-
void fftzbac(std::complex<FPTYPE>* in,
126-
std::complex<FPTYPE>* out) const;
110+
virtual __attribute__((weak)) void fftzfor(std::complex<FPTYPE>* in,
111+
std::complex<FPTYPE>* out) const;
112+
113+
virtual __attribute__((weak)) void fftzbac(std::complex<FPTYPE>* in,
114+
std::complex<FPTYPE>* out) const;
127115

128116
/**
129117
* @brief Forward FFT in x-y direction with real to complex
130118
* @param in input data, real type
131119
* @param out output data, complex type
132-
*
133-
* This function performs the forward FFT in the x-y direction
120+
*
121+
* This function performs the forward FFT in the x-y direction
134122
* with real to complex.There is no difference between fftxyfor.
135123
*/
136-
virtual __attribute__((weak))
137-
void fftxyr2c(FPTYPE* in,
138-
std::complex<FPTYPE>* out) const;
139-
140-
virtual __attribute__((weak))
141-
void fftxyc2r(std::complex<FPTYPE>* in,
142-
FPTYPE* out) const;
143-
124+
virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in,
125+
std::complex<FPTYPE>* out) const;
126+
127+
virtual __attribute__((weak)) void fftxyc2r(std::complex<FPTYPE>* in,
128+
FPTYPE* out) const;
129+
144130
/**
145131
* @brief Forward FFT in 3D
146132
* @param in input data
147133
* @param out output data
148-
*
134+
*
149135
* This function performs the forward FFT for gpu-like fft.
150136
* It involves three axes, x, y, and z. The FFT is applied multiple times
151137
* for fft3D_forward.
152138
*/
153-
virtual __attribute__((weak))
154-
void fft3D_forward(std::complex<FPTYPE>* in,
155-
std::complex<FPTYPE>* out) const;
156-
157-
virtual __attribute__((weak))
158-
void fft3D_backward(std::complex<FPTYPE>* in,
159-
std::complex<FPTYPE>* out) const;
160-
161-
protected:
162-
int nx=0;
163-
int ny=0;
164-
int nz=0;
139+
virtual __attribute__((weak)) void fft3D_forward(std::complex<FPTYPE>* in,
140+
std::complex<FPTYPE>* out) const;
141+
142+
virtual __attribute__((weak)) void fft3D_backward(std::complex<FPTYPE>* in,
143+
std::complex<FPTYPE>* out) const;
144+
145+
protected:
146+
int nx = 0;
147+
int ny = 0;
148+
int nz = 0;
165149
};
166150

167151
template FFT_BASE<float>::FFT_BASE();
168152
template FFT_BASE<double>::FFT_BASE();
169153
template FFT_BASE<float>::~FFT_BASE();
170154
template FFT_BASE<double>::~FFT_BASE();
171-
}
155+
} // namespace ModulePW
172156
#endif // FFT_BASE_H

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,20 @@ void FFT_Bundle::clear()
146146
}
147147
}
148148

149+
void FFT_Bundle::resource_handler(const int flag) const
150+
{
151+
if (this->device=="dsp")
152+
{
153+
if (double_flag)
154+
{
155+
fft_double->resource_handler(flag);
156+
}
157+
if (float_flag)
158+
{
159+
fft_float->resource_handler(flag);
160+
}
161+
}
162+
}
149163
template <>
150164
void FFT_Bundle::fftxyfor(std::complex<float>* in, std::complex<float>* out) const
151165
{

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class FFT_Bundle
8181

8282
void clear();
8383

84+
void resource_handler(const int flag) const;
8485
/**
8586
* @brief Get the real space data.
8687
* @return FPTYPE* the real space data.

source/module_basis/module_pw/module_fft/fft_dsp.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,33 +60,36 @@ void FFT_DSP<double>::setupFFT()
6060
ptr_plan_backward->out = forward_in;
6161
args_back[1] = (unsigned long)ptr_plan_backward;
6262
}
63-
63+
template <>
64+
void FFT_DSP<double>::resource_handler(const int flag) const
65+
{
66+
if (flag==0)
67+
{
68+
hthread_barrier_destroy(b_id);
69+
hthread_group_destroy(thread_id_for);
70+
}
71+
else if (flag==1)
72+
{
73+
INT num_thread = 8;
74+
thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL);
75+
// create b_id for the barrier
76+
b_id = hthread_barrier_create(cluster_id);
77+
args_for[0] = b_id;
78+
args_back[0] = b_id;
79+
}
80+
}
6481
template <>
6582
void FFT_DSP<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
6683
{
67-
INT num_thread = 8;
68-
thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL);
69-
// create b_id for the barrier
70-
b_id = hthread_barrier_create(cluster_id);
71-
args_for[0] = b_id;
7284
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for);
7385
hthread_group_wait(thread_id_for);
74-
hthread_barrier_destroy(b_id);
75-
hthread_group_destroy(thread_id_for);
7686
}
7787

7888
template <>
7989
void FFT_DSP<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
8090
{
81-
INT num_thread = 8;
82-
thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL);
83-
// create b_id for the barrier
84-
b_id = hthread_barrier_create(cluster_id);
85-
args_back[0] = b_id;
8691
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back);
8792
hthread_group_wait(thread_id_for);
88-
hthread_barrier_destroy(b_id);
89-
hthread_group_destroy(thread_id_for);
9093
}
9194
template <>
9295
void FFT_DSP<double>::cleanFFT()

source/module_basis/module_pw/module_fft/fft_dsp_float.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,10 @@ template<>
1616
void FFT_DSP<float>::cleanFFT()
1717
{
1818

19+
}
20+
template<>
21+
void FFT_DSP<float>::resource_handler(const int flag) const
22+
{
23+
1924
}
2025
}

source/module_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
2626
memcpy(auxr, in, this->nrxx * 2 * 8);
2727

2828
// 3d fft
29-
this->fft_bundle.fft3D_forward(gpux, auxr, auxr);
30-
29+
this->fft_bundle.resource_handler(1);
30+
this->fft_bundle.fft3D_forward(gpux,
31+
auxr,
32+
auxr);
33+
this->fft_bundle.resource_handler(0);
3134
// copy the result from the auxr to the out ,while consider the add
3235
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(ctx,
3336
npw_k,
@@ -57,7 +60,9 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
5760
// copy the mapping form the type of stick to the 3dfft
5861
set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
5962
// use 3d fft backward
63+
this->fft_bundle.resource_handler(1);
6064
this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
65+
this->fft_bundle.resource_handler(0);
6166
if (add)
6267
{
6368
const int one = 1;

source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ void Veff<OperatorPW<T, Device>>::act(
5353

5454
int max_npw = nbasis / npol;
5555
const int current_spin = this->isk[this->ik];
56-
56+
#ifdef __DSP
57+
wfcpw->fft_bundle.resource_handler(1);
58+
#endif
5759
// T *porter = new T[wfcpw->nmaxgr];
5860
for (int ib = 0; ib < nbands; ib += npol)
5961
{
@@ -75,6 +77,13 @@ void Veff<OperatorPW<T, Device>>::act(
7577
}
7678
// wfcpw->real2recip(porter, tmhpsi, this->ik, true);
7779
wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
80+
// wfcpw->convolution(this->ctx,
81+
// this->ik,
82+
// this->veff_col,
83+
// tmpsi_in,
84+
// this->veff+current_spin,
85+
// tmhpsi,
86+
// true);
7887
}
7988
else
8089
{
@@ -111,6 +120,9 @@ void Veff<OperatorPW<T, Device>>::act(
111120
tmhpsi += max_npw * npol;
112121
tmpsi_in += max_npw * npol;
113122
}
123+
#ifdef __DSP
124+
wfcpw->fft_bundle.resource_handler(0);
125+
#endif
114126
ModuleBase::timer::tick("Operator", "VeffPW");
115127
}
116128

0 commit comments

Comments
 (0)