Skip to content

Commit 40e73fd

Browse files
committed
Add changes to pw_basis_k
1 parent dedc0e1 commit 40e73fd

File tree

5 files changed

+20
-20
lines changed

5 files changed

+20
-20
lines changed

source/source_basis/module_pw/pw_transform_k.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
366366
auto* auxg = this->fft_bundle.get_auxg_data<double>();
367367
auto* auxr=this->fft_bundle.get_auxr_data<double>();
368368

369-
memset(auxg, 0, this->nst * this->nz * 2 * 8);
369+
memset(auxg, 0, this->nst * this->nz * 2 * 8);
370370
const int startig = ik * this->npwk_max;
371371
const int npwk = this->npwk[ik];
372372

@@ -385,28 +385,28 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
385385
this->gathers_scatterp(auxg, auxr);
386386

387387
this->fft_bundle.fftxybac(auxr, auxr);
388+
389+
#ifdef _OPENMP
390+
#pragma omp parallel for simd schedule(static) aligned(auxr, input1: 64)
391+
#endif
388392
for (int ir = 0; ir < size; ir++)
389393
{
390394
auxr[ir] *= input1[ir];
391395
}
392-
393396
// 3d fft
394397
this->fft_bundle.fftxyfor(auxr, auxr);
395398

396399
this->gatherp_scatters(auxr, auxg);
397400

398401
this->fft_bundle.fftzfor(auxg, auxg);
399402
// accumulate the result from auxg into output, scaled by factor / nxyz
400-
if (add)
401-
{
402-
double tmpfac = factor / double(this->nxyz);
403+
double tmpfac = factor / double(this->nxyz);
403404
#ifdef _OPENMP
404405
#pragma omp parallel for schedule(static, 4096 / sizeof(double))
405406
#endif
406-
for (int igl = 0; igl < npwk; ++igl)
407-
{
408-
output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]];
409-
}
407+
for (int igl = 0; igl < npwk; ++igl)
408+
{
409+
output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]];
410410
}
411411
ModuleBase::timer::tick(this->classname, "convolution");
412412
}

source/source_io/cal_ldos.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ void stm_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,
140140

141141
for (int ib = 0; ib < nbands; ib++)
142142
{
143-
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
143+
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);
144144

145145
const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV;
146146
double weight = en > 0 ? pelec->klist->wk[ik] - pelec->wg(ik, ib) : pelec->wg(ik, ib);
@@ -210,7 +210,7 @@ void ldos_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,
210210

211211
for (int ib = 0; ib < nbands; ib++)
212212
{
213-
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
213+
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);
214214
const double weight = pelec->klist->wk[ik] / ucell.omega;
215215

216216
for (int ir = 0; ir < pelec->basis->nrxx; ir++)

source/source_io/cal_mlkedf_descriptors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ void Cal_MLKEDF_Descriptors::getF_KS(
472472
wfcr[ig] = psi->operator()(ibnd, ig) * std::complex<double>(0.0, fact);
473473
}
474474

475-
pw_psi->recip2real(wfcr, wfcr, ik);
475+
pw_psi->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfcr, wfcr, ik);
476476

477477
for (int ir = 0; ir < this->nx; ++ir)
478478
{

source/source_io/get_wf_lcao.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
179179
// Calculate real-space wave functions
180180
psi_g.fix_k(is);
181181
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
182-
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), is);
182+
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), is);
183183

184184
// Extract real and imaginary parts
185185
std::vector<double> wfc_real(pw_wfc->nrxx);
@@ -399,7 +399,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
399399

400400
// Calculate real-space wave functions
401401
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
402-
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), ik);
402+
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), ik);
403403

404404
// Extract real and imaginary parts
405405
std::vector<double> wfc_real(pw_wfc->nrxx);
@@ -551,7 +551,7 @@ void Get_wf_lcao::set_pw_wfc(const ModulePW::PW_Basis_K* pw_wfc,
551551
}
552552

553553
// call FFT
554-
pw_wfc->real2recip(Porter.data(), &wfc_g(ib, 0), ik);
554+
pw_wfc->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(Porter.data(), &wfc_g(ib, 0), ik);
555555
}
556556

557557
#ifdef __MPI

source/source_io/unk_overlap_pw.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
9393
}
9494

9595
// (3) calculate the overlap in ik_L and ik_R
96-
wfcpw->real2recip(psi_r, psi_r, ik_R);
96+
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_r, psi_r, ik_R);
9797

9898
for (int ig = 0; ig < evc->get_ngk(ik_R); ig++)
9999
{
@@ -197,8 +197,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
197197

198198
// (2) fft and get value
199199
rhopw->recip2real(phase, phase);
200-
wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
201-
wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);
200+
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
201+
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);
202202

203203
for (int ir = 0; ir < wfcpw->nrxx; ir++)
204204
{
@@ -207,8 +207,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
207207
}
208208

209209
// (3) calculate the overlap in ik_L and ik_R
210-
wfcpw->real2recip(psi_up, psi_up, ik_L);
211-
wfcpw->real2recip(psi_down, psi_down, ik_L);
210+
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_up, psi_up, ik_L);
211+
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_down, psi_down, ik_L);
212212

213213
for (int i = 0; i < PARAM.globalv.npol; i++)
214214
{

0 commit comments

Comments
 (0)