Skip to content

Commit dedc0e1

Browse files
committed
add convulution for veff
1 parent 1bee98c commit dedc0e1

File tree

4 files changed

+99
-10
lines changed

4 files changed

+99
-10
lines changed

source/source_basis/module_pw/pw_basis_k.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ class PW_Basis_K : public PW_Basis
135135
const int ik,
136136
const bool add = false,
137137
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
138-
#if defined(__DSP)
139138
template <typename FPTYPE, typename Device>
140139
void convolution(const Device* ctx,
141140
const int ik,
@@ -145,6 +144,16 @@ class PW_Basis_K : public PW_Basis
145144
std::complex<FPTYPE>* output,
146145
const bool add = false,
147146
const FPTYPE factor =1.0) const ;
147+
#if defined(__DSP)
148+
template <typename FPTYPE, typename Device>
149+
void convolution_dsp(const Device* ctx,
150+
const int ik,
151+
const int size,
152+
const std::complex<FPTYPE>* input,
153+
const FPTYPE* input1,
154+
std::complex<FPTYPE>* output,
155+
const bool add = false,
156+
const FPTYPE factor =1.0) const ;
148157

149158
template <typename FPTYPE>
150159
void real2recip_dsp(const std::complex<FPTYPE>* in,

source/source_basis/module_pw/pw_transform_k.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,79 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
337337
this->recip2real(in, out, ik, add, factor);
338338
#endif
339339
}
340+
template <>
341+
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
342+
const int ik,
343+
const int size,
344+
const std::complex<float>* input,
345+
const float* input1,
346+
std::complex<float>* output,
347+
const bool add,
348+
const float factor) const
349+
{
350+
}
351+
352+
template <>
353+
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
354+
const int ik,
355+
const int size,
356+
const std::complex<double>* input,
357+
const double* input1,
358+
std::complex<double>* output,
359+
const bool add,
360+
const double factor) const
361+
{
362+
ModuleBase::timer::tick(this->classname, "convolution");
363+
assert(this->gamma_only == false);
364+
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data<double>(), this->nst * this->nz);
365+
// memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
366+
auto* auxg = this->fft_bundle.get_auxg_data<double>();
367+
auto* auxr=this->fft_bundle.get_auxr_data<double>();
368+
369+
memset(auxg, 0, this->nst * this->nz * 2 * 8);
370+
const int startig = ik * this->npwk_max;
371+
const int npwk = this->npwk[ik];
372+
373+
// copy the mapping form the type of stick to the 3dfft
374+
#ifdef _OPENMP
375+
#pragma omp parallel for schedule(static, 4096 / sizeof(double))
376+
#endif
377+
for (int igl = 0; igl < npwk; ++igl)
378+
{
379+
auxg[this->igl2isz_k[igl + startig]] = input[igl];
380+
}
381+
382+
// use 3d fft backward
383+
this->fft_bundle.fftzbac(auxg, auxg);
384+
385+
this->gathers_scatterp(auxg, auxr);
386+
387+
this->fft_bundle.fftxybac(auxr, auxr);
388+
for (int ir = 0; ir < size; ir++)
389+
{
390+
auxr[ir] *= input1[ir];
391+
}
392+
393+
// 3d fft
394+
this->fft_bundle.fftxyfor(auxr, auxr);
395+
396+
this->gatherp_scatters(auxr, auxg);
397+
398+
this->fft_bundle.fftzfor(auxg, auxg);
399+
// copy the result from the auxr to the out ,while consider the add
400+
if (add)
401+
{
402+
double tmpfac = factor / double(this->nxyz);
403+
#ifdef _OPENMP
404+
#pragma omp parallel for schedule(static, 4096 / sizeof(double))
405+
#endif
406+
for (int igl = 0; igl < npwk; ++igl)
407+
{
408+
output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]];
409+
}
410+
}
411+
ModuleBase::timer::tick(this->classname, "convolution");
412+
}
340413

341414
#if (defined(__CUDA) || defined(__ROCM))
342415
template <>

source/source_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<double>* in,
9191
}
9292
}
9393
template <>
94-
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
94+
void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx,
9595
const int ik,
9696
const int size,
9797
const std::complex<float>* input,
@@ -103,7 +103,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
103103
}
104104

105105
template <>
106-
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
106+
void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx,
107107
const int ik,
108108
const int size,
109109
const std::complex<double>* input,

source/source_pw/module_pwdft/operator_pw/veff_pw.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void Veff<OperatorPW<T, Device>>::act(
6161
ModulePW::FFT_Guard guard(wfcpw->fft_bundle);
6262
for (int ib = 0; ib < nbands; ib += npol)
6363
{
64-
wfcpw->convolution(this->ctx,
64+
wfcpw->convolution_dsp(this->ctx,
6565
this->ik,
6666
this->veff_col,
6767
tmpsi_in,
@@ -96,12 +96,19 @@ void Veff<OperatorPW<T, Device>>::act(
9696
{
9797
for (int ib = 0; ib < nbands; ib += npol)
9898
{
99-
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
100-
// NOTICE: when MPI threads are larger than the number of Z grids
101-
// veff would contain nothing, and nothing should be done in real space
102-
// but the 3DFFT can not be skipped, it will cause hanging
103-
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
104-
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
99+
wfcpw->convolution(this->ctx,
100+
this->ik,
101+
this->veff_col,
102+
tmpsi_in,
103+
this->veff + current_spin * this->veff_col,
104+
tmhpsi,
105+
true);
106+
// wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
107+
// // NOTICE: when MPI threads are larger than the number of Z grids
108+
// // veff would contain nothing, and nothing should be done in real space
109+
// // but the 3DFFT can not be skipped, it will cause hanging
110+
// veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
111+
// wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
105112
tmhpsi += psi_offset;
106113
tmpsi_in += psi_offset;
107114
}

0 commit comments

Comments
 (0)