Skip to content

Commit d15e1ea

Browse files
committed
add the 3d FFT forward
1 parent 12c1c40 commit d15e1ea

File tree

1 file changed

+42
-49
lines changed

1 file changed

+42
-49
lines changed

source/module_basis/module_pw/test_gpu/pw_basis_k_batch.cpp

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,63 @@ using namespace std;
1515
class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
1616
{
1717
public:
18-
const int batch = 1; // Number of batches
18+
const int batch = 2; // Number of batches
1919
const int npwk = 30; // Number of planewaves
2020
const int nxyz = 1000; // Size of the 3D grid
2121
std::vector<int> box_index; // Index mapping for 3D grid
2222
int* d_box_index=nullptr; // Device memory for box_index
23-
std::vector<std::complex<double>> rhog; // rhoG(K space) data for the test,
24-
std::complex<double>* d_rhog = nullptr; // Device memory for rhoG data
25-
std::complex<double>* d_rhog_batch = nullptr; // Device memory for rhoG output data
26-
std::vector<std::complex<double>> rhor; // Device memory for output rhoR(R space) data
27-
std::complex<double>* d_rhor = nullptr; // Device memory for rhoR data
28-
std::complex<double>* d_rhor_batch = nullptr; // Device memory for rhoR output batch data
23+
std::vector<std::complex<double>> psig; // psig(K space) data for the test,
24+
std::complex<double>* d_psig = nullptr; // Device memory for psig data
25+
std::complex<double>* d_psig_batch = nullptr; // Device copy of psig used as input to the batched path
26+
std::vector<std::complex<double>> psir; // Host buffer for psir (R space) data
27+
std::complex<double>* d_psir = nullptr; // Device memory for psir data
28+
std::complex<double>* d_psir_batch = nullptr; // Device memory for psir output batch data
2929
ModulePW::FFT_Bundle ft_gpu; // FFT bundle for 3D FFT operations on GPU
3030
ModulePW::FFT_Bundle ft_gpu_batch; // FFT bundle for 3D FFT operations on batch-GPU
3131
void SetUp() override
3232
{
3333
box_index.resize(npwk);
34-
rhog.resize(npwk * batch);
35-
rhor.resize(nxyz * batch);
34+
psig.resize(npwk * batch);
35+
psir.resize(nxyz * batch);
3636

3737
resize_memory_int_gpu_op()(d_box_index, npwk);
38-
resize_memory_complex_gpu_op()(d_rhog, npwk * batch);
39-
resize_memory_complex_gpu_op()(d_rhor, nxyz * batch);
40-
resize_memory_complex_gpu_op()(d_rhog_batch, npwk * batch);
41-
resize_memory_complex_gpu_op()(d_rhor_batch, nxyz * batch);
38+
resize_memory_complex_gpu_op()(d_psig, npwk * batch);
39+
resize_memory_complex_gpu_op()(d_psir, nxyz * batch);
40+
resize_memory_complex_gpu_op()(d_psig_batch, npwk * batch);
41+
resize_memory_complex_gpu_op()(d_psir_batch, nxyz * batch);
4242

4343
// Initialize the box_index and input with some values
4444
int idx = 0;
4545
std::generate_n(box_index.begin(), npwk, [&idx] { return idx * idx++; });
4646
idx =0;
4747
int npwk = box_index.size();
48-
// Initialize rhog with some complex values,it generates a complex number
48+
// Initialize psig with complex values: each element is a complex number
4949
// with real part as sqrt(idx) and imaginary part as 1/(idx+1),
50-
// thus in different batches the values of rhog will be different.
51-
std::generate_n(rhog.begin(), npwk * batch, [&idx,npwk]
50+
// thus in different batches the values of psig will be different.
51+
std::generate_n(psig.begin(), npwk * batch, [&idx,npwk]
5252
{
5353
idx ++;
5454
return std::complex<double>(std::sqrt(idx), 1.0/(idx+1));
5555
});
5656
synchronize_memory_int_h2d_op()(d_box_index, box_index.data(), npwk);
57-
synchronize_memory_complex_h2d_op()(d_rhog, rhog.data(), npwk * batch);
58-
synchronize_memory_complex_h2d_op()(d_rhog_batch, rhog.data(), npwk * batch);
57+
synchronize_memory_complex_h2d_op()(d_psig, psig.data(), npwk * batch);
58+
synchronize_memory_complex_h2d_op()(d_psig_batch, psig.data(), npwk * batch);
5959
// Configure both FFT bundles (ft_gpu and ft_gpu_batch) on a 10 x 10 x 10 grid
6060
ft_gpu.setfft("gpu", "double");
6161
ft_gpu.initfft(10, 10, 10 , 1, 1, 1, 1, 1, 1);
6262
ft_gpu.setupFFT();
63-
ft_gpu_batch.setfft("gpu", "double");
63+
ft_gpu_batch.setfft("gpu_batch", "double");
6464
ft_gpu_batch.initfft(10, 10, 10 , 1, 1, 1, 1, 1, 1);
6565
ft_gpu_batch.setupFFT();
6666
}
6767
void TearDown() override
6868
{
6969
box_index.clear();
70-
rhog.clear();
71-
rhor.clear();
70+
psig.clear();
71+
psir.clear();
7272
delete_memory_int_gpu_op()(d_box_index);
73-
delete_memory_complex_gpu_op()(d_rhog);
74-
delete_memory_complex_gpu_op()(d_rhor);
73+
delete_memory_complex_gpu_op()(d_psig);
74+
delete_memory_complex_gpu_op()(d_psir);
7575
ft_gpu.clear();
7676
ft_gpu_batch.clear();
7777
}
@@ -85,62 +85,55 @@ TEST_F(PW_BASIS_K_BATCH_GPU_TEST,convulution)
8585
EXPECT_EQ(box_index[i], i * i);
8686
}
8787

88-
// STEP 2 check the input rhog has been
88+
// STEP 2 check the input psig has been
8989
// correctly mapped to the 3D grid
90-
std::vector<std::complex<double>> compute_rhor(nxyz * batch);
91-
std::vector<std::complex<double>> compute_rhor_batch(nxyz * batch);
90+
std::vector<std::complex<double>> compute_psir(nxyz * batch);
91+
std::vector<std::complex<double>> compute_psir_batch(nxyz * batch);
9292
for (int i = 0; i< batch; i++)
9393
{
9494
ModulePW::set_3d_fft_box_op<double,
9595
base_device::DEVICE_GPU>()
9696
(
9797
npwk,
9898
d_box_index,
99-
d_rhog + i * npwk,
100-
d_rhor + i * nxyz
99+
d_psig + i * npwk,
100+
d_psir + i * nxyz
101101
);
102-
synchronize_memory_complex_d2h_op()(compute_rhor.data()+i * nxyz, d_rhor + i *nxyz, nxyz);
102+
synchronize_memory_complex_d2h_op()(compute_psir.data()+i * nxyz, d_psir + i *nxyz, nxyz);
103103
}
104104
ModulePW::set_3d_fft_box_op<double,
105105
base_device::DEVICE_GPU>()
106106
(
107107
npwk,
108108
nxyz,
109109
d_box_index,
110-
d_rhog_batch,
111-
d_rhor_batch,
110+
d_psig_batch,
111+
d_psir_batch,
112112
batch
113113
);
114114

115-
synchronize_memory_complex_d2h_op()(compute_rhor_batch.data(), d_rhor_batch,nxyz * batch);
115+
synchronize_memory_complex_d2h_op()(compute_psir_batch.data(), d_psir_batch,nxyz * batch);
116116
for (int i = 0; i < nxyz*batch ; ++i)
117117
{
118-
EXPECT_NEAR(compute_rhor[i].real(), compute_rhor_batch[i].real(), 1e-7);
119-
EXPECT_NEAR(compute_rhor[i].imag(), compute_rhor_batch[i].imag(), 1e-7);
118+
EXPECT_NEAR(compute_psir[i].real(), compute_psir_batch[i].real(), 1e-7);
119+
EXPECT_NEAR(compute_psir[i].imag(), compute_psir_batch[i].imag(), 1e-7);
120120
}
121121

122122
// STEP 3 perform the 3D FFT forward operation
123-
std::vector<std::complex<double>> compute_rhor1(nxyz * batch,0);
124-
std::vector<std::complex<double>> compute_rhor_batch1(nxyz * batch,0);
125123

126-
for (int i=0;i< batch; i++)
124+
for (int i=0;i<batch;i++)
127125
{
128-
ft_gpu.fft3D_forward(d_rhor, d_rhor);
129-
synchronize_memory_complex_d2h_op()(compute_rhor1.data(), d_rhor , nxyz);
126+
ft_gpu.fft3D_forward(d_psir + i *nxyz, d_psir + i *nxyz);
127+
130128
}
131-
// ft_gpu.fft3D_backward(d_rhor, d_rhor);
132-
for (int i=0;i< batch; i++)
133-
{
134-
ft_gpu_batch.fft3D_forward(d_rhor_batch, d_rhor_batch);
135-
synchronize_memory_complex_d2h_op()(compute_rhor_batch1.data(), d_rhor_batch , nxyz);
136-
}
137-
// ft_gpu.fft3D_forward(d_rhor_batch, d_rhor_batch);
138-
// synchronize_memory_complex_d2h_op()(compute_rhor_batch.data(), d_rhor_batch,nxyz * batch);
129+
ft_gpu_batch.fft3D_forward(d_psir_batch, d_psir_batch );
130+
synchronize_memory_complex_d2h_op()(compute_psir.data(),d_psir , nxyz * batch);
131+
synchronize_memory_complex_d2h_op()(compute_psir_batch.data(), d_psir_batch,nxyz * batch);
139132

140133
for (int i = 0; i < nxyz *batch ; ++i)
141134
{
142-
EXPECT_NEAR(compute_rhor1[i].real(), compute_rhor_batch1[i].real(), 1e-4);
143-
// EXPECT_NEAR(compute_rhor[i].imag(), compute_rhor_batch[i].imag(), 1e-4);
135+
EXPECT_NEAR(compute_psir[i].real(), compute_psir_batch[i].real(), 1e-4);
136+
EXPECT_NEAR(compute_psir[i].imag(), compute_psir_batch[i].imag(), 1e-4);
144137
}
145138

146139
}

0 commit comments

Comments
 (0)