22#include " source_basis/module_pw/kernels/pw_op.h"
33#include " pw_basis_k.h"
44#include " pw_gatherscatter.h"
5-
5+ # include " source_pw/module_pwdft/kernels/veff_op.h "
66#include < cassert>
77#include < complex>
88
@@ -337,6 +337,63 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
337337 this ->recip2real (in, out, ik, add, factor);
338338 #endif
339339}
340+ template <typename FPTYPE>
341+ void PW_Basis_K::convolution_cpu (const int ik,
342+ const int size,
343+ const std::complex <FPTYPE>* input,
344+ const FPTYPE* input1,
345+ std::complex <FPTYPE>* output,
346+ const bool add,
347+ const FPTYPE factor) const
348+ {
349+ ModuleBase::timer::tick (this ->classname , " convolution" );
350+ assert (this ->gamma_only == false );
351+ // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data<double>(), this->nst * this->nz);
352+ // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
353+ auto * auxg = this ->fft_bundle .get_auxg_data <FPTYPE>();
354+ auto * auxr=this ->fft_bundle .get_auxr_data <FPTYPE>();
355+
356+ memset (auxg, 0 , this ->nst * this ->nz * 2 * 8 );
357+ const int startig = ik * this ->npwk_max ;
358+ const int npwk = this ->npwk [ik];
359+
360+ // copy the mapping form the type of stick to the 3dfft
361+ #ifdef _OPENMP
362+ #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
363+ #endif
364+ for (int igl = 0 ; igl < npwk; ++igl)
365+ {
366+ auxg[this ->igl2isz_k [igl + startig]] = input[igl];
367+ }
368+
369+ // use 3d fft backward
370+ this ->fft_bundle .fftzbac (auxg, auxg);
371+
372+ this ->gathers_scatterp (auxg, auxr);
373+
374+ this ->fft_bundle .fftxybac (auxr, auxr);
375+ for (int ir = 0 ; ir < size; ir++)
376+ {
377+ auxr[ir] *= input1[ir];
378+ }
379+
380+ // 3d fft
381+ this ->fft_bundle .fftxyfor (auxr, auxr);
382+
383+ this ->gatherp_scatters (auxr, auxg);
384+
385+ this ->fft_bundle .fftzfor (auxg, auxg);
386+ // copy the result from the auxr to the out ,while consider the add
387+ FPTYPE tmpfac = factor / FPTYPE (this ->nxyz );
388+ #ifdef _OPENMP
389+ #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
390+ #endif
391+ for (int igl = 0 ; igl < npwk; ++igl)
392+ {
393+ output[igl] += tmpfac * auxg[this ->igl2isz_k [igl + startig]];
394+ }
395+ ModuleBase::timer::tick (this ->classname , " convolution" );
396+ }
340397
341398#if (defined(__CUDA) || defined(__ROCM))
342399template <>
@@ -534,6 +591,50 @@ void PW_Basis_K::recip2real_gpu(const std::complex<FPTYPE>* in,
534591 ModuleBase::timer::tick (this ->classname , " recip_to_real gpu" );
535592}
536593
594+ template <typename FPTYPE>
595+ void PW_Basis_K::convolution_gpu (const int ik,
596+ const int size,
597+ const std::complex <FPTYPE>* input,
598+ const FPTYPE* input1,
599+ std::complex <FPTYPE>* output,
600+ const bool add,
601+ const FPTYPE factor) const
602+ {
603+ ModuleBase::timer::tick (this ->classname , " convolution" );
604+
605+ assert (this ->gamma_only == false );
606+ const base_device::DEVICE_GPU* gpux;
607+ // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
608+
609+ base_device::memory::set_memory_op<std::complex <FPTYPE>, base_device::DEVICE_GPU>()(
610+ this ->fft_bundle .get_auxr_3d_data <FPTYPE>(),
611+ 0 ,
612+ this ->nxyz );
613+ auto * auxr = this ->fft_bundle .get_auxr_3d_data <FPTYPE>();
614+ const int startig = ik * this ->npwk_max ;
615+ const int npw_k = this ->npwk [ik];
616+
617+ // copy the mapping form the type of stick to the 3dfft
618+ set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k, this ->ig2ixyz_k + startig, input, auxr);
619+
620+ // use 3d fft backward
621+ this ->fft_bundle .fft3D_backward (auxr, auxr);
622+
623+ hamilt::veff_pw_op<FPTYPE,base_device::DEVICE_GPU>()(gpux,size,auxr,input1);
624+
625+ // 3d fft
626+ this ->fft_bundle .fft3D_forward (auxr, auxr);
627+ // copy the result from the auxr to the out ,while consider the add
628+ set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
629+ this ->nxyz ,
630+ add,
631+ factor,
632+ this ->ig2ixyz_k + startig,
633+ auxr,
634+ output);
635+ ModuleBase::timer::tick (this ->classname , " convolution" );
636+ }
637+
537638template void PW_Basis_K::real2recip_gpu<float >(const std::complex <float >*,
538639 std::complex <float >*,
539640 const int ,
@@ -557,8 +658,35 @@ template void PW_Basis_K::recip2real_gpu<double>(const std::complex<double>*,
557658 const int ,
558659 const bool ,
559660 const double ) const ;
560-
661+ template void PW_Basis_K::convolution_gpu<float >(const int ik,
662+ const int size,
663+ const std::complex <float >* input,
664+ const float * input1,
665+ std::complex <float >* output,
666+ const bool add,
667+ const float factor) const ;
668+ template void PW_Basis_K::convolution_gpu<double >(const int ik,
669+ const int size,
670+ const std::complex <double >* input,
671+ const double * input1,
672+ std::complex <double >* output,
673+ const bool add,
674+ const double factor) const ;
561675#endif
676+ template void PW_Basis_K::convolution_cpu<float >(const int ik,
677+ const int size,
678+ const std::complex <float >* input,
679+ const float * input1,
680+ std::complex <float >* output,
681+ const bool add,
682+ const float factor) const ;
683+ template void PW_Basis_K::convolution_cpu<double >(const int ik,
684+ const int size,
685+ const std::complex <double >* input,
686+ const double * input1,
687+ std::complex <double >* output,
688+ const bool add,
689+ const double factor) const ;
562690
563691template void PW_Basis_K::real2recip<float >(const float * in,
564692 std::complex <float >* out,
0 commit comments