Skip to content

Commit a987eb5

Browse files
committed
add change
1 parent 575192a commit a987eb5

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

source/source_basis/module_pw/module_fft/fft_cuda.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ template <>
16 16
void FFT_CUDA<float>::setupFFT()
17 17
{
18 18
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
19-
resmem_cd_op()(this->c_auxr_3d, 2*this->nx * this->ny * this->nz);
19+
resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);
20 20
}
21 21
template <>
22 22
void FFT_CUDA<double>::setupFFT()
23 23
{
24 24
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
25-
resmem_zd_op()(this->z_auxr_3d, 2*this->nx * this->ny * this->nz);
25+
resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz);
26 26
}
27 27
template <>
28 28
void FFT_CUDA<float>::cleanFFT()

source/source_basis/module_pw/pw_transform_convolution.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,9 @@ void PW_Basis_K::convolution_cpu(const int ik,
89 89
memset(augx1, 0, this->nst * this->nz * sizeof(FPTYPE)*2);
90 90
for (int igl = 0; igl < npwk; ++igl)
91 91
{
92-
augr[this->igl2isz_k[igl + startig]] = input[igl];
93-
}
94-
for (int igl =0 ; igl < npwk ; ++igl)
95-
{
96-
augr1[this->igl2isz_k[igl + startig]] = input[igl+max_npw];
92+
const int idx=this->igl2isz_k[igl + startig];
93+
augr[idx] = input[igl];
94+
augr1[idx] = input[igl+max_npw];
97 95
}
98 96
// use 3d fft backward
99 97
this->fft_bundle.fftzbac(augr, augr);
@@ -118,15 +116,12 @@ void PW_Basis_K::convolution_cpu(const int ik,
118 116
#ifdef _OPENMP
119 117
#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
120 118
#endif
121-
for (int igl = 0; igl < npwk; ++igl)
122-
{
123-
output[igl] += tmpfac * augr[this->igl2isz_k[igl + startig]];
124-
}
125-
126-
for (int igl =0 ; igl < npwk ; ++igl)
127-
{
128-
output[igl+max_npw] += tmpfac * augr1[this->igl2isz_k[igl + startig]];
129-
}
119+
for (int igl = 0; igl < npwk; ++igl)
120+
{
121+
const int idx=this->igl2isz_k[igl + startig];
122+
output[igl] += tmpfac * augr[idx];
123+
output[igl+max_npw] += tmpfac * augr1[idx];
124+
}
130 125
ModuleBase::timer::tick(this->classname, "convolution");
131 126
}
132 127

0 commit comments

Comments
 (0)