Skip to content

Commit 691b2ed

Browse files
A-006Critsium-xymohanchen
authored
Reactor FFT format and add RAII for resource handler (#6156)
* add change * add comment * fix resource handler * add RALL * update make unique * rename FFT guard * change compute mode * change compute mode * fix compute bug * add the nullpter * add barce * update compile bug --------- Co-authored-by: Critsium-xy <[email protected]> Co-authored-by: Mohan Chen <[email protected]>
1 parent 2027450 commit 691b2ed

File tree

9 files changed

+152
-67
lines changed

9 files changed

+152
-67
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ if(ENABLE_MPI)
269269
list(APPEND math_libs MPI::MPI_CXX)
270270
endif()
271271

272+
272273
if (USE_DSP)
273274
add_compile_definitions(__DSP)
274275
target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1})

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class FFT_BASE
1212
virtual ~FFT_BASE() {};
1313

1414
/**
15-
* @brief Initialize the fft parameters As virtual function.
15+
* @brief Initialize the fft parameters as virtual function.
1616
*
1717
* The function is used to initialize the fft parameters.
1818
*/
@@ -30,32 +30,40 @@ class FFT_BASE
3030
virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in);
3131

3232
/**
33-
* @brief Setup the fft Plan and data As pure virtual function.
33+
* @brief Setup the fft plan and data as pure virtual function.
3434
*
3535
* The function is set as pure virtual function.In order to
3636
* override the function in the derived class.In the derived
37-
* class, the function is used to setup the fft Plan and data.
37+
* class, the function is used to setup the fft plan and data.
3838
*/
3939
virtual void setupFFT() = 0;
4040

4141
/**
42-
* @brief Clean the fft Plan As pure virtual function.
42+
* @brief Clean the fft plan as pure virtual function.
4343
*
4444
* The function is set as pure virtual function.In order to
4545
* override the function in the derived class.In the derived
46-
* class, the function is used to clean the fft Plan.
46+
* class, the function is used to clean the fft plan.
4747
*/
4848
virtual void cleanFFT() = 0;
4949

5050
/**
51-
* @brief Clear the fft data As pure virtual function.
51+
* @brief Clear the fft data as pure virtual function.
5252
*
5353
* The function is set as pure virtual function.In order to
5454
* override the function in the derived class.In the derived
5555
* class, the function is used to clear the fft data.
5656
*/
5757
virtual void clear() = 0;
58-
58+
/**
59+
* @brief Allocate and destory the resoure in FFT running time,
60+
* Now it only used in the DSP mode.
61+
*
62+
* The function is set as pure virtual function.In order to
63+
* override the function in the derived class.In the derived
64+
* class, the function is used to allocate and destory the
65+
* resoure in FFT running time.
66+
*/
5967
virtual void resource_handler(const int flag) const {};
6068
/**
6169
* @brief Get the real space data in cpu-like fft

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,22 @@ void FFT_Bundle::initfft(int nx_in,
5050
if (this->precision == "single" || this->precision == "mixing")
5151
{
5252
float_flag = true;
53+
if (this->precision == "mixing")
54+
{
55+
double_flag = true;
56+
}
5357
#if not defined(__ENABLE_FLOAT_FFTW)
5458
if (this->device == "cpu")
5559
{
5660
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please enable float fftw in the cmake to use float fft");
5761
}
5862
#endif
5963
}
60-
if (this->precision == "double" || this->precision == "mixing")
64+
else if (this->precision == "double")
6165
{
6266
double_flag = true;
67+
}else{
68+
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please set the precision to single or double or mixing");
6369
}
6470
#if defined(__DSP)
6571
if (device == "dsp")
@@ -70,24 +76,23 @@ void FFT_Bundle::initfft(int nx_in,
7076
}
7177
fft_double = make_unique<FFT_DSP<double>>();
7278
fft_double->initfft(nx_in, ny_in, nz_in);
73-
}
79+
}else
7480
#endif
7581
if (device == "cpu")
7682
{
77-
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
78-
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
7983
if (float_flag)
8084
{
85+
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
8186
fft_float
8287
->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in);
8388
}
8489
if (double_flag)
8590
{
91+
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
8692
fft_double
8793
->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in);
8894
}
89-
}
90-
if (device == "gpu")
95+
}else if (device == "gpu")
9196
{
9297
#if defined(__ROCM)
9398
fft_float = make_unique<FFT_ROCM<float>>();
@@ -100,6 +105,8 @@ void FFT_Bundle::initfft(int nx_in,
100105
fft_double = make_unique<FFT_CUDA<double>>();
101106
fft_double->initfft(nx_in, ny_in, nz_in);
102107
#endif
108+
}else{
109+
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please set the device to cpu or gpu or dsp");
103110
}
104111
}
105112

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,18 @@ class FFT_Bundle
203203
std::string device = "cpu";
204204
std::string precision = "double";
205205
};
206+
// Use RAII (Resource Acquisition Is Initialization) to
207+
// control the resources used by hthread when setting the DSP
208+
struct FFT_Guard
209+
{
210+
const FFT_Bundle& fft_;
211+
FFT_Guard(const FFT_Bundle& fft) : fft_(fft)
212+
{fft_.resource_handler(1);}
213+
~FFT_Guard()
214+
{
215+
fft_.resource_handler(0);
216+
}
217+
};
218+
206219
} // namespace ModulePW
207220
#endif // FFT_H

source/module_basis/module_pw/module_fft/fft_dsp.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void FFT_DSP<double>::setupFFT()
6363
template <>
6464
void FFT_DSP<double>::resource_handler(const int flag) const
6565
{
66-
if (flag==0)
66+
if (flag == 0)
6767
{
6868
hthread_barrier_destroy(b_id);
6969
hthread_group_destroy(thread_id_for);
@@ -76,6 +76,8 @@ void FFT_DSP<double>::resource_handler(const int flag) const
7676
b_id = hthread_barrier_create(cluster_id);
7777
args_for[0] = b_id;
7878
args_back[0] = b_id;
79+
}else{
80+
ModuleBase::WARNING_QUIT("FFT_DSP", "Error use of fft resource handle");
7981
}
8082
}
8183
template <>

source/module_basis/module_pw/module_fft/fft_dsp.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
namespace ModulePW
1414
{
15+
1516
template <typename FPTYPE>
1617
class FFT_DSP : public FFT_BASE<FPTYPE>
1718
{
@@ -24,7 +25,12 @@ class FFT_DSP : public FFT_BASE<FPTYPE>
2425
void clear() override;
2526

2627
void cleanFFT() override;
27-
28+
/**
29+
* @brief Control the allocation or deallocation of hthread
30+
* resource
31+
* @param flag 0: deallocate, 1: allocate
32+
*/
33+
void resource_handler(const int flag) const override;
2834
/**
2935
* @brief Initialize the fft parameters
3036
* @param nx_in number of grid points in x direction

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class PW_Basis_K : public PW_Basis
187187
const typename GetTypeReal<TK>::type factor = 1.0) const
188188
{
189189
#if defined(__DSP)
190-
this->recip2real_dsp(in, out, ik, add, factor);
190+
this->real2recip_dsp(in, out, ik, add, factor);
191191
#else
192192
this->real2recip(in,out,ik,add,factor);
193193
#endif

source/module_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,30 @@
88
#if defined (__DSP)
99
namespace ModulePW
1010
{
11-
template <typename FPTYPE>
12-
void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
13-
std::complex<FPTYPE>* out,
11+
template <>
12+
void PW_Basis_K::real2recip_dsp(const std::complex<float>* in,
13+
std::complex<float>* out,
1414
const int ik,
1515
const bool add,
16-
const FPTYPE factor) const
16+
const float factor) const
17+
{
18+
19+
}
20+
template <>
21+
void PW_Basis_K::recip2real_dsp(const std::complex<float>* in,
22+
std::complex<float>* out,
23+
const int ik,
24+
const bool add,
25+
const float factor) const
26+
{
27+
28+
}
29+
template <>
30+
void PW_Basis_K::real2recip_dsp(const std::complex<double>* in,
31+
std::complex<double>* out,
32+
const int ik,
33+
const bool add,
34+
const double factor) const
1735
{
1836
const base_device::DEVICE_CPU* ctx;
1937
const base_device::DEVICE_GPU* gpux;
@@ -31,20 +49,20 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
3149
auxr);
3250
this->fft_bundle.resource_handler(0);
3351
// copy the result from the auxr to the out ,while consider the add
34-
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(npw_k,
52+
set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k,
3553
this->nxyz,
3654
add,
3755
factor,
3856
this->ig2ixyz_k_cpu.data() + startig,
3957
auxr,
4058
out);
4159
}
42-
template <typename FPTYPE>
43-
void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
44-
std::complex<FPTYPE>* out,
60+
template <>
61+
void PW_Basis_K::recip2real_dsp(const std::complex<double>* in,
62+
std::complex<double>* out,
4563
const int ik,
4664
const bool add,
47-
const FPTYPE factor) const
65+
const double factor) const
4866
{
4967
assert(this->gamma_only == false);
5068
const base_device::DEVICE_CPU* ctx;
@@ -128,16 +146,16 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
128146
ModuleBase::timer::tick(this->classname, "convolution");
129147
}
130148

131-
// template void PW_Basis_K::real2recip_dsp<float>(const std::complex<float>* in,
132-
// std::complex<float>* out,
133-
// const int ik,
134-
// const bool add,
135-
// const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns)
136-
// template void PW_Basis_K::recip2real_dsp<float>(const std::complex<float>* in,
137-
// std::complex<float>* out,
138-
// const int ik,
139-
// const bool add,
140-
// const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny)
149+
template void PW_Basis_K::real2recip_dsp<float>(const std::complex<float>* in,
150+
std::complex<float>* out,
151+
const int ik,
152+
const bool add,
153+
const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns)
154+
template void PW_Basis_K::recip2real_dsp<float>(const std::complex<float>* in,
155+
std::complex<float>* out,
156+
const int ik,
157+
const bool add,
158+
const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny)
141159

142160
template void PW_Basis_K::real2recip_dsp<double>(const std::complex<double>* in,
143161
std::complex<double>* out,

source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp

Lines changed: 62 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -52,52 +52,82 @@ void Veff<OperatorPW<T, Device>>::act(
5252
{
5353
setmem_complex_op()(tmhpsi, 0, nbasis*nbands/npol);
5454
}
55-
5655
int max_npw = nbasis / npol;
5756
const int current_spin = this->isk[this->ik];
58-
57+
const int psi_offset= max_npw * npol;
5958
#ifdef __DSP
60-
wfcpw->fft_bundle.resource_handler(1);
61-
#endif
62-
63-
for (int ib = 0; ib < nbands; ib += npol)
59+
if (npol == 1)
60+
{
61+
ModulePW::FFT_Guard guard(wfcpw->fft_bundle);
62+
for (int ib = 0; ib < nbands; ib += npol)
63+
{
64+
wfcpw->convolution(this->ctx,
65+
this->ik,
66+
this->veff_col,
67+
tmpsi_in,
68+
this->veff + current_spin * this->veff_col,
69+
tmhpsi,
70+
true);
71+
tmhpsi += psi_offset;
72+
tmpsi_in += psi_offset;
73+
}
74+
}else if (npol == 2)
75+
{
76+
const Real* current_veff[4]={nullptr};
77+
for (int is = 0; is < 4; is++)
78+
{
79+
current_veff[is] = this->veff + is * this->veff_col;
80+
}
81+
for (int ib = 0; ib < nbands; ib += npol)
82+
{
83+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
84+
wfcpw->recip_to_real<T, Device>(tmpsi_in + max_npw, this->porter1, this->ik);
85+
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
86+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
87+
wfcpw->real_to_recip<T, Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
88+
tmhpsi += psi_offset;
89+
tmpsi_in += psi_offset;
90+
}
91+
}else{
92+
ModuleBase::WARNING_QUIT("VeffPW", "npol should be 1 or 2 or veff_col equal to 0\n");
93+
}
94+
#else
95+
if (npol == 1)
6496
{
65-
if (npol == 1)
97+
for (int ib = 0; ib < nbands; ib += npol)
6698
{
67-
wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
99+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
68100
// NOTICE: when MPI threads are larger than the number of Z grids
69101
// veff would contain nothing, and nothing should be done in real space
70102
// but the 3DFFT can not be skipped, it will cause hanging
71-
if(this->veff_col != 0)
72-
{
73-
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
74-
}
75-
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
103+
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
104+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
105+
tmhpsi += psi_offset;
106+
tmpsi_in += psi_offset;
76107
}
77-
else
108+
}
109+
else if (npol == 2)
110+
{
111+
const Real* current_veff[4]={nullptr};
112+
for (int is = 0; is < 4; is++)
113+
{
114+
current_veff[is] = this->veff + is * this->veff_col;
115+
}
116+
for (int ib = 0; ib < nbands; ib += npol)
78117
{
79118
// FFT to real space and do things.
80-
wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
81-
wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik);
82-
if(this->veff_col != 0)
83-
{
84-
/// denghui added at 20221109
85-
const Real* current_veff[4];
86-
for(int is = 0; is < 4; is++)
87-
{
88-
current_veff[is] = this->veff + is * this->veff_col ; // for CPU device
89-
}
90-
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
91-
}
119+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
120+
wfcpw->recip_to_real<T, Device>(tmpsi_in + max_npw, this->porter1, this->ik);
121+
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
92122
// FFT back to G space.
93-
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
94-
wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
123+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
124+
wfcpw->real_to_recip<T, Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
125+
tmhpsi += psi_offset;
126+
tmpsi_in += psi_offset;
95127
}
96-
tmhpsi += max_npw * npol;
97-
tmpsi_in += max_npw * npol;
128+
}else{
129+
ModuleBase::WARNING_QUIT("VeffPW", "npol should be 1 or 2 or veff_col equal to 0\n");
98130
}
99-
#ifdef __DSP
100-
wfcpw->fft_bundle.resource_handler(0);
101131
#endif
102132
ModuleBase::timer::tick("Operator", "veff_pw");
103133
}

0 commit comments

Comments
 (0)