Skip to content

Commit 82b8590

Browse files
committed
add convolution for nspin_4
1 parent ed64543 commit 82b8590

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

source/source_basis/module_pw/module_fft/fft_cuda.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ template <>
1616
void FFT_CUDA<float>::setupFFT()
1717
{
1818
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
19-
resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);
19+
resmem_cd_op()(this->c_auxr_3d, 2*this->nx * this->ny * this->nz);
2020
}
2121
template <>
2222
void FFT_CUDA<double>::setupFFT()
2323
{
2424
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
25-
resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz);
25+
resmem_zd_op()(this->z_auxr_3d, 2*this->nx * this->ny * this->nz);
2626
}
2727
template <>
2828
void FFT_CUDA<float>::cleanFFT()

source/source_basis/module_pw/pw_transform_convolution.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,18 @@ void PW_Basis_K::convolution_gpu(const int ik,
189189
assert(this->gamma_only == false);
190190
base_device::DEVICE_GPU* gpu_ctx;
191191
// memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
192-
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
193-
tmp,
194-
0,
195-
2*this->nxyz);
196192
const int startig = ik * this->npwk_max;
197193
const int npw_k = this->npwk[ik];
198-
auto *auxg = tmp;
199-
auto *auxg1 = &tmp[size];
194+
auto *auxg = this->fft_bundle.get_auxr_3d_data<FPTYPE>();
195+
auto *auxg1 = &this->fft_bundle.get_auxr_3d_data<FPTYPE>()[size];
196+
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
197+
auxg,
198+
0,
199+
this->nxyz);
200+
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
201+
auxg1,
202+
0,
203+
this->nxyz);
200204
set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
201205
this->ig2ixyz_k + startig,
202206
input,
@@ -215,14 +219,14 @@ void PW_Basis_K::convolution_gpu(const int ik,
215219
// use 3d fft backward
216220
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
217221
this->nxyz,
218-
add,
222+
true,
219223
factor,
220224
this->ig2ixyz_k + startig,
221225
auxg,
222226
output);
223227
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
224228
this->nxyz,
225-
add,
229+
true,
226230
factor,
227231
this->ig2ixyz_k + startig,
228232
auxg1,

source/source_basis/module_pw/pw_transform_k.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#include "source_basis/module_pw/kernels/pw_op.h"
33
#include "pw_basis_k.h"
44
#include "pw_gatherscatter.h"
5-
#include "source_pw/module_pwdft/kernels/veff_op.h"
65
namespace ModulePW
76
{
87

0 commit comments

Comments
 (0)