Skip to content

Commit 4b770fb

Browse files
committed
add convolution for psi
1 parent 2f59a6f commit 4b770fb

File tree

3 files changed

+177
-7
lines changed

3 files changed

+177
-7
lines changed

source/source_basis/module_pw/pw_basis_k.h

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ class PW_Basis_K : public PW_Basis
158158
const int ik,
159159
const bool add = false,
160160
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
161-
162161
#endif
163162

164163
template <typename FPTYPE, typename Device>
@@ -176,7 +175,6 @@ class PW_Basis_K : public PW_Basis
176175
const bool add = false,
177176
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
178177

179-
180178
template <typename TK,
181179
typename Device,
182180
typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
@@ -245,6 +243,47 @@ class PW_Basis_K : public PW_Basis
245243
{
246244
this->recip2real_gpu(in, out, ik, add, factor);
247245
}
246+
template <typename FPTYPE, typename Device,
247+
typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0>
248+
void convolution(const int ik,
249+
const int size,
250+
const FPTYPE* input,
251+
const typename GetTypeReal<FPTYPE>::type* input1,
252+
FPTYPE* output,
253+
const bool add = false,
254+
const typename GetTypeReal<FPTYPE>::type factor =1.0) const
255+
{
256+
this->convolution_gpu(ik, size, input, input1, output, add, factor);
257+
}
258+
template <typename FPTYPE, typename Device,
259+
typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
260+
void convolution(const int ik,
261+
const int size,
262+
const FPTYPE* input,
263+
const typename GetTypeReal<FPTYPE>::type* input1,
264+
FPTYPE* output,
265+
const bool add = false,
266+
const typename GetTypeReal<FPTYPE>::type factor =1.0) const
267+
{
268+
this->convolution_cpu(ik, size, input, input1, output, add, factor);
269+
}
270+
template <typename FPTYPE>
271+
void convolution_cpu(const int ik,
272+
const int size,
273+
const std::complex<FPTYPE>* input,
274+
const FPTYPE* input1,
275+
std::complex<FPTYPE>* output,
276+
const bool add = false,
277+
const FPTYPE factor = 1.0) const;
278+
279+
template <typename FPTYPE>
280+
void convolution_gpu(const int ik,
281+
const int size,
282+
const std::complex<FPTYPE>* input,
283+
const FPTYPE* input1,
284+
std::complex<FPTYPE>* output,
285+
const bool add = false,
286+
const FPTYPE factor = 1.0) const;
248287

249288
public:
250289
//operator:

source/source_basis/module_pw/pw_transform_k.cpp

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include "source_basis/module_pw/kernels/pw_op.h"
33
#include "pw_basis_k.h"
44
#include "pw_gatherscatter.h"
5-
5+
#include "source_pw/module_pwdft/kernels/veff_op.h"
66
#include <cassert>
77
#include <complex>
88

@@ -337,6 +337,63 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
337337
this->recip2real(in, out, ik, add, factor);
338338
#endif
339339
}
340+
template <typename FPTYPE>
341+
void PW_Basis_K::convolution_cpu(const int ik,
342+
const int size,
343+
const std::complex<FPTYPE>* input,
344+
const FPTYPE* input1,
345+
std::complex<FPTYPE>* output,
346+
const bool add,
347+
const FPTYPE factor) const
348+
{
349+
ModuleBase::timer::tick(this->classname, "convolution");
350+
assert(this->gamma_only == false);
351+
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data<double>(), this->nst * this->nz);
352+
// memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
353+
auto* auxg = this->fft_bundle.get_auxg_data<FPTYPE>();
354+
auto* auxr=this->fft_bundle.get_auxr_data<FPTYPE>();
355+
356+
memset(auxg, 0, this->nst * this->nz * 2 * 8);
357+
const int startig = ik * this->npwk_max;
358+
const int npwk = this->npwk[ik];
359+
360+
// copy the mapping form the type of stick to the 3dfft
361+
#ifdef _OPENMP
362+
#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
363+
#endif
364+
for (int igl = 0; igl < npwk; ++igl)
365+
{
366+
auxg[this->igl2isz_k[igl + startig]] = input[igl];
367+
}
368+
369+
// use 3d fft backward
370+
this->fft_bundle.fftzbac(auxg, auxg);
371+
372+
this->gathers_scatterp(auxg, auxr);
373+
374+
this->fft_bundle.fftxybac(auxr, auxr);
375+
for (int ir = 0; ir < size; ir++)
376+
{
377+
auxr[ir] *= input1[ir];
378+
}
379+
380+
// 3d fft
381+
this->fft_bundle.fftxyfor(auxr, auxr);
382+
383+
this->gatherp_scatters(auxr, auxg);
384+
385+
this->fft_bundle.fftzfor(auxg, auxg);
386+
// copy the result from the auxr to the out ,while consider the add
387+
FPTYPE tmpfac = factor / FPTYPE(this->nxyz);
388+
#ifdef _OPENMP
389+
#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
390+
#endif
391+
for (int igl = 0; igl < npwk; ++igl)
392+
{
393+
output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]];
394+
}
395+
ModuleBase::timer::tick(this->classname, "convolution");
396+
}
340397

341398
#if (defined(__CUDA) || defined(__ROCM))
342399
template <>
@@ -534,6 +591,50 @@ void PW_Basis_K::recip2real_gpu(const std::complex<FPTYPE>* in,
534591
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
535592
}
536593

594+
template <typename FPTYPE>
595+
void PW_Basis_K::convolution_gpu(const int ik,
596+
const int size,
597+
const std::complex<FPTYPE>* input,
598+
const FPTYPE* input1,
599+
std::complex<FPTYPE>* output,
600+
const bool add,
601+
const FPTYPE factor) const
602+
{
603+
ModuleBase::timer::tick(this->classname, "convolution");
604+
605+
assert(this->gamma_only == false);
606+
const base_device::DEVICE_GPU* gpux;
607+
// memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
608+
609+
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
610+
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
611+
0,
612+
this->nxyz);
613+
auto* auxr = this->fft_bundle.get_auxr_3d_data<FPTYPE>();
614+
const int startig = ik * this->npwk_max;
615+
const int npw_k = this->npwk[ik];
616+
617+
// copy the mapping form the type of stick to the 3dfft
618+
set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k, this->ig2ixyz_k + startig, input, auxr);
619+
620+
// use 3d fft backward
621+
this->fft_bundle.fft3D_backward(auxr, auxr);
622+
623+
hamilt::veff_pw_op<FPTYPE,base_device::DEVICE_GPU>()(gpux,size,auxr,input1);
624+
625+
// 3d fft
626+
this->fft_bundle.fft3D_forward(auxr, auxr);
627+
// copy the result from the auxr to the out ,while consider the add
628+
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
629+
this->nxyz,
630+
add,
631+
factor,
632+
this->ig2ixyz_k + startig,
633+
auxr,
634+
output);
635+
ModuleBase::timer::tick(this->classname, "convolution");
636+
}
637+
537638
template void PW_Basis_K::real2recip_gpu<float>(const std::complex<float>*,
538639
std::complex<float>*,
539640
const int,
@@ -557,8 +658,35 @@ template void PW_Basis_K::recip2real_gpu<double>(const std::complex<double>*,
557658
const int,
558659
const bool,
559660
const double) const;
560-
661+
template void PW_Basis_K::convolution_gpu<float>(const int ik,
662+
const int size,
663+
const std::complex<float>* input,
664+
const float* input1,
665+
std::complex<float>* output,
666+
const bool add,
667+
const float factor) const;
668+
template void PW_Basis_K::convolution_gpu<double>(const int ik,
669+
const int size,
670+
const std::complex<double>* input,
671+
const double* input1,
672+
std::complex<double>* output,
673+
const bool add,
674+
const double factor) const;
561675
#endif
676+
template void PW_Basis_K::convolution_cpu<float>(const int ik,
677+
const int size,
678+
const std::complex<float>* input,
679+
const float* input1,
680+
std::complex<float>* output,
681+
const bool add,
682+
const float factor) const;
683+
template void PW_Basis_K::convolution_cpu<double>(const int ik,
684+
const int size,
685+
const std::complex<double>* input,
686+
const double* input1,
687+
std::complex<double>* output,
688+
const bool add,
689+
const double factor) const;
562690

563691
template void PW_Basis_K::real2recip<float>(const float* in,
564692
std::complex<float>* out,

source/source_pw/module_pwdft/operator_pw/veff_pw.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,15 @@ void Veff<OperatorPW<T, Device>>::act(
9696
{
9797
for (int ib = 0; ib < nbands; ib += npol)
9898
{
99-
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
99+
wfcpw->convolution<T, Device>(this->ik,
100+
this->veff_col,
101+
tmpsi_in,
102+
this->veff + current_spin * this->veff_col,
103+
tmhpsi,
104+
true);
100105
// NOTICE: when MPI threads are larger than the number of Z grids
101106
// veff would contain nothing, and nothing should be done in real space
102107
// but the 3DFFT can not be skipped, it will cause hanging
103-
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
104-
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
105108
tmhpsi += psi_offset;
106109
tmpsi_in += psi_offset;
107110
}

0 commit comments

Comments
 (0)