@@ -15,63 +15,63 @@ using namespace std;
1515class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
1616{
1717 public:
18- const int batch = 1 ; // Number of batches
18+ const int batch = 2 ; // Number of batches
1919 const int npwk = 30 ; // Number of planewaves
2020 const int nxyz = 1000 ; // Size of the 3D grid
2121 std::vector<int > box_index; // Index mapping for 3D grid
2222 int * d_box_index=nullptr ; // Device memory for box_index
23- std::vector<std::complex <double >> rhog ; // rhoG (K space) data for the test,
24- std::complex <double >* d_rhog = nullptr ; // Device memory for rhoG data
25- std::complex <double >* d_rhog_batch = nullptr ; // Device memory for rhoG output data
26- std::vector<std::complex <double >> rhor ; // Device memory for output rhoR (R space) data
27- std::complex <double >* d_rhor = nullptr ; // Device memory for rhoR data
28- std::complex <double >* d_rhor_batch = nullptr ; // Device memory for rhoR output batch data
23+ std::vector<std::complex <double >> psig ; // psig (K space) data for the test,
24+ std::complex <double >* d_psig = nullptr ; // Device memory for psig data
25+ std::complex <double >* d_psig_batch = nullptr ; // Device memory for psig output data
26+ std::vector<std::complex <double >> psir ; // Device memory for output psir (R space) data
27+ std::complex <double >* d_psir = nullptr ; // Device memory for psir data
28+ std::complex <double >* d_psir_batch = nullptr ; // Device memory for psir output batch data
2929 ModulePW::FFT_Bundle ft_gpu; // FFT bundle for 3D FFT operations on GPU
3030 ModulePW::FFT_Bundle ft_gpu_batch; // FFT bundle for 3D FFT operations on batch-GPU
3131 void SetUp () override
3232 {
3333 box_index.resize (npwk);
34- rhog .resize (npwk * batch);
35- rhor .resize (nxyz * batch);
34+ psig .resize (npwk * batch);
35+ psir .resize (nxyz * batch);
3636
3737 resize_memory_int_gpu_op ()(d_box_index, npwk);
38- resize_memory_complex_gpu_op ()(d_rhog , npwk * batch);
39- resize_memory_complex_gpu_op ()(d_rhor , nxyz * batch);
40- resize_memory_complex_gpu_op ()(d_rhog_batch , npwk * batch);
41- resize_memory_complex_gpu_op ()(d_rhor_batch , nxyz * batch);
38+ resize_memory_complex_gpu_op ()(d_psig , npwk * batch);
39+ resize_memory_complex_gpu_op ()(d_psir , nxyz * batch);
40+ resize_memory_complex_gpu_op ()(d_psig_batch , npwk * batch);
41+ resize_memory_complex_gpu_op ()(d_psir_batch , nxyz * batch);
4242
4343 // Initialize the box_index and input with some values
4444 int idx = 0 ;
4545 std::generate_n (box_index.begin (), npwk, [&idx] { return idx * idx++; });
4646 idx =0 ;
4747 int npwk = box_index.size ();
48- // Initialize rhog with some complex values,it generates a complex number
48+ // Initialize psig with some complex values,it generates a complex number
4949 // with real part as sqrt(idx) and imaginary part as 1/(idx+1),
50- // thus in different batches the values of rhog will be different.
51- std::generate_n (rhog .begin (), npwk * batch, [&idx,npwk]
50+ // thus in different batches the values of psig will be different.
51+ std::generate_n (psig .begin (), npwk * batch, [&idx,npwk]
5252 {
5353 idx ++;
5454 return std::complex <double >(std::sqrt (idx), 1.0 /(idx+1 ));
5555 });
5656 synchronize_memory_int_h2d_op ()(d_box_index, box_index.data (), npwk);
57- synchronize_memory_complex_h2d_op ()(d_rhog, rhog .data (), npwk * batch);
58- synchronize_memory_complex_h2d_op ()(d_rhog_batch, rhog .data (), npwk * batch);
57+ synchronize_memory_complex_h2d_op ()(d_psig, psig .data (), npwk * batch);
58+ synchronize_memory_complex_h2d_op ()(d_psig_batch, psig .data (), npwk * batch);
5959 // Initialize the box_index with some values
6060 ft_gpu.setfft (" gpu" , " double" );
6161 ft_gpu.initfft (10 , 10 , 10 , 1 , 1 , 1 , 1 , 1 , 1 );
6262 ft_gpu.setupFFT ();
63- ft_gpu_batch.setfft (" gpu " , " double" );
63+ ft_gpu_batch.setfft (" gpu_batch " , " double" );
6464 ft_gpu_batch.initfft (10 , 10 , 10 , 1 , 1 , 1 , 1 , 1 , 1 );
6565 ft_gpu_batch.setupFFT ();
6666 }
6767 void TearDown () override
6868 {
6969 box_index.clear ();
70- rhog .clear ();
71- rhor .clear ();
70+ psig .clear ();
71+ psir .clear ();
7272 delete_memory_int_gpu_op ()(d_box_index);
73- delete_memory_complex_gpu_op ()(d_rhog );
74- delete_memory_complex_gpu_op ()(d_rhor );
73+ delete_memory_complex_gpu_op ()(d_psig );
74+ delete_memory_complex_gpu_op ()(d_psir );
7575 ft_gpu.clear ();
7676 ft_gpu_batch.clear ();
7777 }
@@ -85,62 +85,55 @@ TEST_F(PW_BASIS_K_BATCH_GPU_TEST,convulution)
8585 EXPECT_EQ (box_index[i], i * i);
8686 }
8787
88- // STEP 2 check the input rhog has been
88+ // STEP 2 check the input psig has been
8989 // correctly mapped to the 3D grid
90- std::vector<std::complex <double >> compute_rhor (nxyz * batch);
91- std::vector<std::complex <double >> compute_rhor_batch (nxyz * batch);
90+ std::vector<std::complex <double >> compute_psir (nxyz * batch);
91+ std::vector<std::complex <double >> compute_psir_batch (nxyz * batch);
9292 for (int i = 0 ; i< batch; i++)
9393 {
9494 ModulePW::set_3d_fft_box_op<double ,
9595 base_device::DEVICE_GPU>()
9696 (
9797 npwk,
9898 d_box_index,
99- d_rhog + i * npwk,
100- d_rhor + i * nxyz
99+ d_psig + i * npwk,
100+ d_psir + i * nxyz
101101 );
102- synchronize_memory_complex_d2h_op ()(compute_rhor .data ()+i * nxyz, d_rhor + i *nxyz, nxyz);
102+ synchronize_memory_complex_d2h_op ()(compute_psir .data ()+i * nxyz, d_psir + i *nxyz, nxyz);
103103 }
104104 ModulePW::set_3d_fft_box_op<double ,
105105 base_device::DEVICE_GPU>()
106106 (
107107 npwk,
108108 nxyz,
109109 d_box_index,
110- d_rhog_batch ,
111- d_rhor_batch ,
110+ d_psig_batch ,
111+ d_psir_batch ,
112112 batch
113113 );
114114
115- synchronize_memory_complex_d2h_op ()(compute_rhor_batch .data (), d_rhor_batch ,nxyz * batch);
115+ synchronize_memory_complex_d2h_op ()(compute_psir_batch .data (), d_psir_batch ,nxyz * batch);
116116 for (int i = 0 ; i < nxyz*batch ; ++i)
117117 {
118- EXPECT_NEAR (compute_rhor [i].real (), compute_rhor_batch [i].real (), 1e-7 );
119- EXPECT_NEAR (compute_rhor [i].imag (), compute_rhor_batch [i].imag (), 1e-7 );
118+ EXPECT_NEAR (compute_psir [i].real (), compute_psir_batch [i].real (), 1e-7 );
119+ EXPECT_NEAR (compute_psir [i].imag (), compute_psir_batch [i].imag (), 1e-7 );
120120 }
121121
122122 // STEP 3 perform the 3D FFT forward operation
123- std::vector<std::complex <double >> compute_rhor1 (nxyz * batch,0 );
124- std::vector<std::complex <double >> compute_rhor_batch1 (nxyz * batch,0 );
125123
126- for (int i=0 ;i< batch; i++)
124+ for (int i=0 ;i<batch;i++)
127125 {
128- ft_gpu.fft3D_forward (d_rhor, d_rhor );
129- synchronize_memory_complex_d2h_op ()(compute_rhor1. data (), d_rhor , nxyz);
126+ ft_gpu.fft3D_forward (d_psir + i *nxyz, d_psir + i *nxyz );
127+
130128 }
131- // ft_gpu.fft3D_backward(d_rhor, d_rhor);
132- for (int i=0 ;i< batch; i++)
133- {
134- ft_gpu_batch.fft3D_forward (d_rhor_batch, d_rhor_batch);
135- synchronize_memory_complex_d2h_op ()(compute_rhor_batch1.data (), d_rhor_batch , nxyz);
136- }
137- // ft_gpu.fft3D_forward(d_rhor_batch, d_rhor_batch);
138- // synchronize_memory_complex_d2h_op()(compute_rhor_batch.data(), d_rhor_batch,nxyz * batch);
129+ ft_gpu_batch.fft3D_forward (d_psir_batch, d_psir_batch );
130+ synchronize_memory_complex_d2h_op ()(compute_psir.data (),d_psir , nxyz * batch);
131+ synchronize_memory_complex_d2h_op ()(compute_psir_batch.data (), d_psir_batch,nxyz * batch);
139132
140133 for (int i = 0 ; i < nxyz *batch ; ++i)
141134 {
142- EXPECT_NEAR (compute_rhor1 [i].real (), compute_rhor_batch1 [i].real (), 1e-4 );
143- // EXPECT_NEAR(compute_rhor [i].imag(), compute_rhor_batch [i].imag(), 1e-4);
135+ EXPECT_NEAR (compute_psir [i].real (), compute_psir_batch [i].real (), 1e-4 );
136+ EXPECT_NEAR (compute_psir [i].imag (), compute_psir_batch [i].imag (), 1e-4 );
144137 }
145138
146139}
0 commit comments