1+ #include " module_base/timer.h"
2+ #include " module_basis/module_pw/kernels/pw_op.h"
3+ #include " pw_basis_k.h"
4+
5+ #include < cassert>
6+ #include < complex>
7+ #include < string>
8+ namespace ModulePW
9+ {
10+ template <typename FPTYPE>
11+ void PW_Basis_K::real2recip_3d (const std::complex <FPTYPE>* in,
12+ std::complex <FPTYPE>* out,
13+ const int ik,
14+ const bool add,
15+ const FPTYPE factor) const
16+ {
17+ ModuleBase::timer::tick (this ->classname ," real2recip_3d" );
18+ const base_device::DEVICE_CPU* ctx;
19+ const base_device::DEVICE_GPU* gpux;
20+ assert (this ->gamma_only == false );
21+ auto * auxr = this ->fft_bundle .get_auxr_3d_data <double >();
22+
23+ const int startig = ik * this ->npwk_max ;
24+ const int npw_k = this ->npwk [ik];
25+ memcpy (auxr,in,this ->nrxx *2 *8 );
26+ this ->fft_bundle .fft3D_forward (gpux,
27+ auxr,
28+ auxr);
29+ set_real_to_recip_output_op<double , base_device::DEVICE_CPU>()(ctx,
30+ npw_k,
31+ this ->nxyz ,
32+ add,
33+ factor,
34+ this ->ig2ixyz_k_cpu + startig,
35+ this ->fft_bundle .get_auxr_3d_data <double >(),
36+ out);
37+ ModuleBase::timer::tick (this ->classname ," real2recip_3d" );
38+ }
39+
40+ template <typename FPTYPE>
41+ void PW_Basis_K::recip2real_3d (const std::complex <FPTYPE>* in,
42+ std::complex <FPTYPE>* out,
43+ const int ik,
44+ const bool add,
45+ const FPTYPE factor) const
46+ {
47+ ModuleBase::timer::tick (this ->classname ," recip2real_3d" );
48+
49+ assert (this ->gamma_only == false );
50+ const base_device::DEVICE_CPU* ctx;
51+ const base_device::DEVICE_GPU* gpux;
52+ auto * auxr = this ->fft_bundle .get_auxr_3d_data <double >();
53+ memset (auxr,0 ,this ->nrxx *2 *8 );
54+ const int startig = ik * this ->npwk_max ;
55+ const int npw_k = this ->npwk [ik];
56+
57+ set_3d_fft_box_op<double , base_device::DEVICE_CPU>()(ctx,
58+ npw_k,
59+ this ->ig2ixyz_k_cpu + startig,
60+ in,
61+ auxr);
62+ this ->fft_bundle .fft3D_backward (gpux,auxr,auxr);
63+ set_recip_to_real_output_op<double , base_device::DEVICE_CPU>()(ctx,
64+ this ->nrxx ,
65+ add,
66+ factor,
67+ auxr,
68+ out);
69+ ModuleBase::timer::tick (this ->classname ," recip2real_3d" );
70+ }
71+
72+ template void PW_Basis_K::real2recip_3d<double >(const std::complex <double >* in,
73+ std::complex <double >* out,
74+ const int ik,
75+ const bool add,
76+ const double factor) const ; // in:(nplane,nx*ny) ; out(nz, ns)
77+ template void PW_Basis_K::recip2real_3d<double >(const std::complex <double >* in,
78+ std::complex <double >* out,
79+ const int ik,
80+ const bool add,
81+ const double factor) const ; // in:(nz, ns) ; out(nplane,nx*ny)
82+ }
0 commit comments