44#include " module_base/vector3.h"
55#include " module_basis/module_pw/pw_basis_k.h"
66#include " module_basis/module_pw/module_fft/fft_bundle.h"
7+ #include " module_hamilt_pw/hamilt_pwdft/kernels/veff_op.h"
78#include " pw_test.h"
89#include < complex>
910#include < vector>
@@ -26,20 +27,24 @@ class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
2627 std::vector<std::complex <double >> psir; // Device memory for output psir(R space) data
2728 std::complex <double >* d_psir = nullptr ; // Device memory for psir data
2829 std::complex <double >* d_psir_batch = nullptr ; // Device memory for psir output batch data
30+ std::vector<double > veff; // Device memory for effective potential
31+ double * d_veff = nullptr ; // Device memory for effective potential
2932 ModulePW::FFT_Bundle ft_gpu; // FFT bundle for 3D FFT operations on GPU
3033 ModulePW::FFT_Bundle ft_gpu_batch; // FFT bundle for 3D FFT operations on batch-GPU
34+
3135 void SetUp () override
3236 {
3337 box_index.resize (npwk);
3438 psig.resize (npwk * batch);
3539 psir.resize (nxyz * batch);
40+ veff.resize (nxyz * batch);
3641
3742 resize_memory_int_gpu_op ()(d_box_index, npwk);
3843 resize_memory_complex_gpu_op ()(d_psig, npwk * batch);
3944 resize_memory_complex_gpu_op ()(d_psir, nxyz * batch);
4045 resize_memory_complex_gpu_op ()(d_psig_batch, npwk * batch);
4146 resize_memory_complex_gpu_op ()(d_psir_batch, nxyz * batch);
42-
47+ resize_memory_double_gpu_op ()(d_veff, nxyz * batch);
4348 // Initialize the box_index and input with some values
4449 int idx = 0 ;
4550 std::generate_n (box_index.begin (), npwk, [&idx] { return idx * idx++; });
@@ -53,9 +58,16 @@ class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
5358 idx ++;
5459 return std::complex <double >(std::sqrt (idx), 1.0 /(idx+1 ));
5560 });
61+ idx=0 ;
62+ std::generate_n (veff.begin (), nxyz * batch, [&idx]
63+ {
64+ idx++;
65+ return (1.0 /(idx+1 )+std::sqrt (idx));
66+ });
5667 synchronize_memory_int_h2d_op ()(d_box_index, box_index.data (), npwk);
5768 synchronize_memory_complex_h2d_op ()(d_psig, psig.data (), npwk * batch);
5869 synchronize_memory_complex_h2d_op ()(d_psig_batch, psig.data (), npwk * batch);
70+ synchronize_memory_double_h2d_op ()(d_veff, veff.data (), nxyz * batch);
5971 // Initialize the box_index with some values
6072 ft_gpu.setfft (" gpu" , " double" );
6173 ft_gpu.initfft (10 , 10 , 10 , 1 , 1 , 1 , 1 , 1 , 1 );
@@ -72,6 +84,9 @@ class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
7284 delete_memory_int_gpu_op ()(d_box_index);
7385 delete_memory_complex_gpu_op ()(d_psig);
7486 delete_memory_complex_gpu_op ()(d_psir);
87+ delete_memory_complex_gpu_op ()(d_psig_batch);
88+ delete_memory_complex_gpu_op ()(d_psir_batch);
89+ delete_memory_double_gpu_op ()(d_veff);
7590 ft_gpu.clear ();
7691 ft_gpu_batch.clear ();
7792 }
@@ -120,20 +135,131 @@ TEST_F(PW_BASIS_K_BATCH_GPU_TEST,convulution)
120135 }
121136
122137 // STEP 3 perform the 3D FFT forward operation
123-
124138 for (int i=0 ;i<batch;i++)
125139 {
126- ft_gpu.fft3D_forward (d_psir + i *nxyz, d_psir + i *nxyz);
140+ ft_gpu.fft3D_backward (d_psir + i *nxyz, d_psir + i *nxyz);
127141
128142 }
129- ft_gpu_batch.fft3D_forward (d_psir_batch, d_psir_batch );
143+ ft_gpu_batch.fft3D_backward (d_psir_batch, d_psir_batch );
130144 synchronize_memory_complex_d2h_op ()(compute_psir.data (),d_psir , nxyz * batch);
131145 synchronize_memory_complex_d2h_op ()(compute_psir_batch.data (), d_psir_batch,nxyz * batch);
132146
133147 for (int i = 0 ; i < nxyz *batch ; ++i)
134148 {
135- EXPECT_NEAR (compute_psir[i].real (), compute_psir_batch[i].real (), 1e-4 );
136- EXPECT_NEAR (compute_psir[i].imag (), compute_psir_batch[i].imag (), 1e-4 );
149+ EXPECT_NEAR (compute_psir[i].real (), compute_psir_batch[i].real (), 1e-7 );
150+ EXPECT_NEAR (compute_psir[i].imag (), compute_psir_batch[i].imag (), 1e-7 );
151+ }
152+ // STEP 4 set the reciprocal to real space operation
153+ for (int i=0 ; i< batch; i++)
154+ {
155+ ModulePW::set_recip_to_real_output_op<double ,
156+ base_device::DEVICE_GPU>()
157+ (
158+ nxyz,
159+ true ,
160+ 1.0 ,
161+ d_psir + i * nxyz,
162+ d_psir + i * nxyz
163+ );
164+ }
165+ ModulePW::set_recip_to_real_output_op<double ,
166+ base_device::DEVICE_GPU>()
167+ (
168+ nxyz,
169+ true ,
170+ 1.0 ,
171+ d_psir_batch,
172+ d_psir_batch,
173+ batch
174+ );
175+ synchronize_memory_complex_d2h_op ()(compute_psir.data (),d_psir , nxyz * batch);
176+ synchronize_memory_complex_d2h_op ()(compute_psir_batch.data (), d_psir_batch,nxyz * batch);
177+ for (int i = 0 ; i < nxyz *batch ; ++i)
178+ {
179+ EXPECT_NEAR (compute_psir[i].real (), compute_psir_batch[i].real (), 1e-7 );
180+ EXPECT_NEAR (compute_psir[i].imag (), compute_psir_batch[i].imag (), 1e-7 );
181+ }
182+
183+ // STEP 5 use veff_pw operation to compute
184+ const base_device::DEVICE_GPU * dev_gpu;
185+ for (int i = 0 ; i < batch; ++i)
186+ {
187+ hamilt::veff_pw_op<double ,
188+ base_device::DEVICE_GPU>()
189+ (
190+ dev_gpu,
191+ nxyz,
192+ d_psir + i * nxyz,
193+ d_veff + i * nxyz
194+ );
195+ }
196+ hamilt::veff_pw_op<double ,
197+ base_device::DEVICE_GPU>()
198+ (
199+ dev_gpu,
200+ nxyz,
201+ d_psir_batch,
202+ d_veff,
203+ batch
204+ );
205+
206+ synchronize_memory_complex_d2h_op ()(compute_psir.data (),d_psir , nxyz * batch);
207+ synchronize_memory_complex_d2h_op ()(compute_psir_batch.data (), d_psir_batch,nxyz * batch);
208+
209+ for (int i = 0 ; i < nxyz *batch ; ++i)
210+ {
211+ EXPECT_NEAR (compute_psir[i].real (), compute_psir_batch[i].real (), 1e-7 );
212+ EXPECT_NEAR (compute_psir[i].imag (), compute_psir_batch[i].imag (), 1e-7 );
213+ }
214+
215+ // STEP 6 perform the 3D FFT backward operation
216+ std::vector<std::complex <double >> compute_psig (nxyz * batch);
217+ std::vector<std::complex <double >> compute_psig_batch (nxyz * batch);
218+ for (int i=0 ;i<batch;i++)
219+ {
220+ ft_gpu.fft3D_forward (d_psir + i *nxyz, d_psir + i *nxyz);
221+ }
222+ ft_gpu_batch.fft3D_forward (d_psir_batch, d_psir_batch);
223+ synchronize_memory_complex_d2h_op ()(compute_psig.data (),d_psir , nxyz * batch);
224+ synchronize_memory_complex_d2h_op ()(compute_psig_batch.data (), d_psir_batch,nxyz * batch);
225+ for (int i = 0 ; i < nxyz *batch ; ++i)
226+ {
227+ EXPECT_NEAR (compute_psig[i].real (), compute_psig_batch[i].real (), 1e-7 );
228+ EXPECT_NEAR (compute_psig[i].imag (), compute_psig_batch[i].imag (), 1e-7 );
229+ }
230+
231+ // STEP 7 check the output psig has been
232+ for (int i =0 ; i< batch;i++)
233+ {
234+ ModulePW::set_real_to_recip_output_op<double ,
235+ base_device::DEVICE_GPU>()
236+ (
237+ npwk,
238+ nxyz,
239+ true ,
240+ 1.0 ,
241+ d_box_index,
242+ d_psir + i * nxyz,
243+ d_psig + i * npwk
244+ );
245+ }
246+ ModulePW::set_real_to_recip_output_op<double ,
247+ base_device::DEVICE_GPU>()
248+ (
249+ npwk,
250+ nxyz,
251+ true ,
252+ 1.0 ,
253+ d_box_index,
254+ d_psir_batch,
255+ d_psig_batch,
256+ batch
257+ );
258+ synchronize_memory_complex_d2h_op ()(compute_psig.data (),d_psig , npwk * batch);
259+ synchronize_memory_complex_d2h_op ()(compute_psig_batch.data (), d_psig_batch,npwk * batch);
260+ for (int i = 0 ; i < npwk *batch ; ++i)
261+ {
262+ EXPECT_NEAR (compute_psig[i].real (), compute_psig_batch[i].real (), 1e-7 );
263+ EXPECT_NEAR (compute_psig[i].imag (), compute_psig_batch[i].imag (), 1e-7 );
137264 }
138-
139265}
0 commit comments