33#include " module_base/module_device/device.h"
44#include " module_base/vector3.h"
55#include " module_basis/module_pw/pw_basis_k.h"
6+ #include " module_basis/module_pw/module_fft/fft_bundle.h"
67#include " pw_test.h"
78#include < complex>
89#include < vector>
@@ -14,45 +15,132 @@ using namespace std;
1415class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
1516{
1617 public:
17- const int batch = 10 ; // Number of batches
18+ const int batch = 1 ; // Number of batches
1819 const int npwk = 30 ; // Number of planewaves
1920 const int nxyz = 1000 ; // Size of the 3D grid
2021 std::vector<int > box_index; // Index mapping for 3D grid
2122 int * d_box_index=nullptr ; // Device memory for box_index
22- std::vector<std::complex <double >> rhog; // Input data for the test,
23+ std::vector<std::complex <double >> rhog; // rhoG(K space) data for the test,
2324 std::complex <double >* d_rhog = nullptr ; // Device memory for rhoG data
24- std::vector<std::complex <double >> rhor = nullptr ; // Device memory for output rhoG data
25- std::complex <double >* d_rhor = nullptr ; // Device memory for output data
25+ std::complex <double >* d_rhog_batch = nullptr ; // Device memory for rhoG output data
26+ std::vector<std::complex <double >> rhor; // Device memory for output rhoR(R space) data
27+ std::complex <double >* d_rhor = nullptr ; // Device memory for rhoR data
28+ std::complex <double >* d_rhor_batch = nullptr ; // Device memory for rhoR output batch data
29+ ModulePW::FFT_Bundle ft_gpu; // FFT bundle for 3D FFT operations on GPU
30+ ModulePW::FFT_Bundle ft_gpu_batch; // FFT bundle for 3D FFT operations on batch-GPU
2631 void SetUp () override
2732 {
2833 box_index.resize (npwk);
29- rhog.resize (npwk);
34+ rhog.resize (npwk * batch);
35+ rhor.resize (nxyz * batch);
36+
3037 resize_memory_int_gpu_op ()(d_box_index, npwk);
31- resize_memory_complex_gpu_op ()(d_rhog, npwk);
38+ resize_memory_complex_gpu_op ()(d_rhog, npwk * batch);
39+ resize_memory_complex_gpu_op ()(d_rhor, nxyz * batch);
40+ resize_memory_complex_gpu_op ()(d_rhog_batch, npwk * batch);
41+ resize_memory_complex_gpu_op ()(d_rhor_batch, nxyz * batch);
42+
3243 // Initialize the box_index and input with some values
3344 int idx = 0 ;
3445 std::generate_n (box_index.begin (), npwk, [&idx] { return idx * idx++; });
3546 idx =0 ;
36- std::generate_n (rhog.begin (), npwk, [&idx] { return std::complex <double >(sqrt (idx), 1 /(idx+1 )); });
47+ int npwk = box_index.size ();
48+ // Initialize rhog with some complex values,it generates a complex number
49+ // with real part as sqrt(idx) and imaginary part as 1/(idx+1),
50+ // thus in different batches the values of rhog will be different.
51+ std::generate_n (rhog.begin (), npwk * batch, [&idx,npwk]
52+ {
53+ idx ++;
54+ return std::complex <double >(std::sqrt (idx), 1.0 /(idx+1 ));
55+ });
3756 synchronize_memory_int_h2d_op ()(d_box_index, box_index.data (), npwk);
38- synchronize_memory_complex_h2d_op ()(d_rhog, rhog.data (), npwk);
57+ synchronize_memory_complex_h2d_op ()(d_rhog, rhog.data (), npwk * batch);
58+ synchronize_memory_complex_h2d_op ()(d_rhog_batch, rhog.data (), npwk * batch);
3959 // Initialize the box_index with some values
40-
41- // resize_memory_int_gpu_op
60+ ft_gpu.setfft (" gpu" , " double" );
61+ ft_gpu.initfft (10 , 10 , 10 , 1 , 1 , 1 , 1 , 1 , 1 );
62+ ft_gpu.setupFFT ();
63+ ft_gpu_batch.setfft (" gpu" , " double" );
64+ ft_gpu_batch.initfft (10 , 10 , 10 , 1 , 1 , 1 , 1 , 1 , 1 );
65+ ft_gpu_batch.setupFFT ();
4266 }
4367 void TearDown () override
4468 {
4569 box_index.clear ();
4670 rhog.clear ();
71+ rhor.clear ();
4772 delete_memory_int_gpu_op ()(d_box_index);
4873 delete_memory_complex_gpu_op ()(d_rhog);
74+ delete_memory_complex_gpu_op ()(d_rhor);
75+ ft_gpu.clear ();
76+ ft_gpu_batch.clear ();
4977 }
5078};
5179
5280TEST_F (PW_BASIS_K_BATCH_GPU_TEST,convulution)
5381{
82+ // STEP 1 set the 3D FFT box operation for CPU
5483 for (int i = 0 ; i < npwk; ++i)
5584 {
5685 EXPECT_EQ (box_index[i], i * i);
5786 }
87+
88+ // STEP 2 check the input rhog has been
89+ // correctly mapped to the 3D grid
90+ std::vector<std::complex <double >> compute_rhor (nxyz * batch);
91+ std::vector<std::complex <double >> compute_rhor_batch (nxyz * batch);
92+ for (int i = 0 ; i< batch; i++)
93+ {
94+ ModulePW::set_3d_fft_box_op<double ,
95+ base_device::DEVICE_GPU>()
96+ (
97+ npwk,
98+ d_box_index,
99+ d_rhog + i * npwk,
100+ d_rhor + i * nxyz
101+ );
102+ synchronize_memory_complex_d2h_op ()(compute_rhor.data ()+i * nxyz, d_rhor + i *nxyz, nxyz);
103+ }
104+ ModulePW::set_3d_fft_box_op<double ,
105+ base_device::DEVICE_GPU>()
106+ (
107+ npwk,
108+ nxyz,
109+ d_box_index,
110+ d_rhog_batch,
111+ d_rhor_batch,
112+ batch
113+ );
114+
115+ synchronize_memory_complex_d2h_op ()(compute_rhor_batch.data (), d_rhor_batch,nxyz * batch);
116+ for (int i = 0 ; i < nxyz*batch ; ++i)
117+ {
118+ EXPECT_NEAR (compute_rhor[i].real (), compute_rhor_batch[i].real (), 1e-7 );
119+ EXPECT_NEAR (compute_rhor[i].imag (), compute_rhor_batch[i].imag (), 1e-7 );
120+ }
121+
122+ // STEP 3 perform the 3D FFT forward operation
123+ std::vector<std::complex <double >> compute_rhor1 (nxyz * batch,0 );
124+ std::vector<std::complex <double >> compute_rhor_batch1 (nxyz * batch,0 );
125+
126+ for (int i=0 ;i< batch; i++)
127+ {
128+ ft_gpu.fft3D_forward (d_rhor, d_rhor);
129+ synchronize_memory_complex_d2h_op ()(compute_rhor1.data (), d_rhor , nxyz);
130+ }
131+ // ft_gpu.fft3D_backward(d_rhor, d_rhor);
132+ for (int i=0 ;i< batch; i++)
133+ {
134+ ft_gpu_batch.fft3D_forward (d_rhor_batch, d_rhor_batch);
135+ synchronize_memory_complex_d2h_op ()(compute_rhor_batch1.data (), d_rhor_batch , nxyz);
136+ }
137+ // ft_gpu.fft3D_forward(d_rhor_batch, d_rhor_batch);
138+ // synchronize_memory_complex_d2h_op()(compute_rhor_batch.data(), d_rhor_batch,nxyz * batch);
139+
140+ for (int i = 0 ; i < nxyz *batch ; ++i)
141+ {
142+ EXPECT_NEAR (compute_rhor1[i].real (), compute_rhor_batch1[i].real (), 1e-4 );
143+ // EXPECT_NEAR(compute_rhor[i].imag(), compute_rhor_batch[i].imag(), 1e-4);
144+ }
145+
58146}
0 commit comments