Skip to content

Commit 848a52a

Browse files
committed
revert convulution
1 parent 40e73fd commit 848a52a

File tree

4 files changed

+10
-99
lines changed

4 files changed

+10
-99
lines changed

source/source_basis/module_pw/pw_basis_k.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,9 @@ class PW_Basis_K : public PW_Basis
135135
const int ik,
136136
const bool add = false,
137137
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
138-
template <typename FPTYPE, typename Device>
139-
void convolution(const Device* ctx,
140-
const int ik,
141-
const int size,
142-
const std::complex<FPTYPE>* input,
143-
const FPTYPE* input1,
144-
std::complex<FPTYPE>* output,
145-
const bool add = false,
146-
const FPTYPE factor =1.0) const ;
147138
#if defined(__DSP)
148139
template <typename FPTYPE, typename Device>
149-
void convolution_dsp(const Device* ctx,
140+
void convolution(const Device* ctx,
150141
const int ik,
151142
const int size,
152143
const std::complex<FPTYPE>* input,

source/source_basis/module_pw/pw_transform_k.cpp

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -337,79 +337,6 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
337337
this->recip2real(in, out, ik, add, factor);
338338
#endif
339339
}
340-
template <>
341-
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
342-
const int ik,
343-
const int size,
344-
const std::complex<float>* input,
345-
const float* input1,
346-
std::complex<float>* output,
347-
const bool add,
348-
const float factor) const
349-
{
350-
}
351-
352-
template <>
353-
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
354-
const int ik,
355-
const int size,
356-
const std::complex<double>* input,
357-
const double* input1,
358-
std::complex<double>* output,
359-
const bool add,
360-
const double factor) const
361-
{
362-
ModuleBase::timer::tick(this->classname, "convolution");
363-
assert(this->gamma_only == false);
364-
// ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data<double>(), this->nst * this->nz);
365-
// memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
366-
auto* auxg = this->fft_bundle.get_auxg_data<double>();
367-
auto* auxr=this->fft_bundle.get_auxr_data<double>();
368-
369-
memset(auxg, 0, this->nst * this->nz * 2 * 8);
370-
const int startig = ik * this->npwk_max;
371-
const int npwk = this->npwk[ik];
372-
373-
// copy the mapping form the type of stick to the 3dfft
374-
#ifdef _OPENMP
375-
#pragma omp parallel for schedule(static, 4096 / sizeof(double))
376-
#endif
377-
for (int igl = 0; igl < npwk; ++igl)
378-
{
379-
auxg[this->igl2isz_k[igl + startig]] = input[igl];
380-
}
381-
382-
// use 3d fft backward
383-
this->fft_bundle.fftzbac(auxg, auxg);
384-
385-
this->gathers_scatterp(auxg, auxr);
386-
387-
this->fft_bundle.fftxybac(auxr, auxr);
388-
389-
#ifdef _OPENMP
390-
#pragma omp parallel for simd schedule(static) aligned(auxr, input1: 64)
391-
#endif
392-
for (int ir = 0; ir < size; ir++)
393-
{
394-
auxr[ir] *= input1[ir];
395-
}
396-
// 3d fft
397-
this->fft_bundle.fftxyfor(auxr, auxr);
398-
399-
this->gatherp_scatters(auxr, auxg);
400-
401-
this->fft_bundle.fftzfor(auxg, auxg);
402-
// copy the result from the auxr to the out ,while consider the add
403-
double tmpfac = factor / double(this->nxyz);
404-
#ifdef _OPENMP
405-
#pragma omp parallel for schedule(static, 4096 / sizeof(double))
406-
#endif
407-
for (int igl = 0; igl < npwk; ++igl)
408-
{
409-
output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]];
410-
}
411-
ModuleBase::timer::tick(this->classname, "convolution");
412-
}
413340

414341
#if (defined(__CUDA) || defined(__ROCM))
415342
template <>

source/source_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<double>* in,
9191
}
9292
}
9393
template <>
94-
void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx,
94+
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
9595
const int ik,
9696
const int size,
9797
const std::complex<float>* input,
@@ -103,7 +103,7 @@ void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx,
103103
}
104104

105105
template <>
106-
void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx,
106+
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
107107
const int ik,
108108
const int size,
109109
const std::complex<double>* input,

source/source_pw/module_pwdft/operator_pw/veff_pw.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void Veff<OperatorPW<T, Device>>::act(
6161
ModulePW::FFT_Guard guard(wfcpw->fft_bundle);
6262
for (int ib = 0; ib < nbands; ib += npol)
6363
{
64-
wfcpw->convolution_dsp(this->ctx,
64+
wfcpw->convolution(this->ctx,
6565
this->ik,
6666
this->veff_col,
6767
tmpsi_in,
@@ -96,19 +96,12 @@ void Veff<OperatorPW<T, Device>>::act(
9696
{
9797
for (int ib = 0; ib < nbands; ib += npol)
9898
{
99-
wfcpw->convolution(this->ctx,
100-
this->ik,
101-
this->veff_col,
102-
tmpsi_in,
103-
this->veff + current_spin * this->veff_col,
104-
tmhpsi,
105-
true);
106-
// wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
107-
// // NOTICE: when MPI threads are larger than the number of Z grids
108-
// // veff would contain nothing, and nothing should be done in real space
109-
// // but the 3DFFT can not be skipped, it will cause hanging
110-
// veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
111-
// wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
99+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
100+
// NOTICE: when MPI threads are larger than the number of Z grids
101+
// veff would contain nothing, and nothing should be done in real space
102+
// but the 3DFFT can not be skipped, it will cause hanging
103+
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
104+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
112105
tmhpsi += psi_offset;
113106
tmpsi_in += psi_offset;
114107
}

0 commit comments

Comments
 (0)