Skip to content

Commit 12c1c40

Browse files
committed
add the file
1 parent 9ca2e6c commit 12c1c40

File tree

5 files changed

+117
-15
lines changed

5 files changed

+117
-15
lines changed

source/module_basis/module_pw/kernels/cuda/pw_op.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ __global__ void set_3d_fft_box_batch(
4141
thrust::complex<FPTYPE>* batch_out = out + batch_idx * nxyz;
4242

4343
const int box_idx = box_index[element_idx];
44-
printf("the batch_idx is %d, the element_idx is %d, the box_idx is %d\n", batch_idx, element_idx, box_idx);
4544
const thrust::complex<FPTYPE> input_val = batch_in[element_idx];
4645
batch_out[box_idx] = input_val;
4746
}

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void FFT_Bundle::initfft(int nx_in,
4545
bool xprime_in,
4646
bool mpifft_in)
4747
{
48-
assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp");
48+
assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp" || this->device == "gpu_batch");
4949
assert(this->precision == "single" || this->precision == "double" || this->precision == "mixing");
5050

5151
if (this->precision == "single" || this->precision == "mixing")
@@ -101,11 +101,25 @@ void FFT_Bundle::initfft(int nx_in,
101101
fft_double = make_unique<FFT_ROCM<double>>();
102102
fft_double->initfft(nx_in, ny_in, nz_in);
103103
#elif defined(__CUDA)
104+
std::cout<<"here is the set of the gpu"<<std::endl;
104105
fft_float = make_unique<FFT_CUDA<float>>();
105106
fft_float->initfft(nx_in, ny_in, nz_in);
107+
fft_double = make_unique<FFT_CUDA<double>>();
108+
fft_double->initfft(nx_in, ny_in, nz_in );
109+
#endif
110+
}else if (device == "gpu_batch")
111+
{
112+
#if defined(__ROCM)
113+
fft_float = make_unique<FFT_ROCM<float>>();
114+
fft_float->initfft(nx_in, ny_in, nz_in);
115+
fft_double = make_unique<FFT_ROCM<double>>();
116+
fft_double->initfft(nx_in, ny_in, nz_in);
117+
#elif defined(__CUDA)
118+
std::cout<<"here is the set of the batch gpu"<<std::endl;
119+
fft_float = make_unique<FFT_CUDA_BATCH<float>>();
120+
fft_float->initfft(nx_in, ny_in, nz_in);
106121
fft_double = make_unique<FFT_CUDA_BATCH<double>>();
107122
fft_double->initfft(nx_in, ny_in, nz_in );
108-
109123
#endif
110124
}else{
111125
// ModuleBase::WARNING_QUIT("FFT_Bundle", "Please set the device to cpu or gpu or dsp");
@@ -238,6 +252,7 @@ template <>
238252
void FFT_Bundle::fft3D_forward(std::complex<double>* in,
239253
std::complex<double>* out) const
240254
{
255+
std::cout<<"FFT_Bundle::fft3D_forward<double> in FFT_bundle"<<std::endl;
241256
fft_double->fft3D_forward(in, out);
242257
}
243258

source/module_basis/module_pw/module_fft/fft_cuda.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@ void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in, std::complex<float>
7474
template <>
7575
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
7676
{
77+
std::cout<<"FFT_CUDA<double>::fft3D_forward"<<std::endl;
7778
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
7879
reinterpret_cast<cufftDoubleComplex*>(in),
7980
reinterpret_cast<cufftDoubleComplex*>(out),
8081
CUFFT_FORWARD));
82+
cudaCheckOnDebug();
8183
}
8284
template <>
8385
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const

source/module_basis/module_pw/module_fft/fft_cuda_batch.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ void FFT_CUDA_BATCH<float>::setupFFT()
2828
template <>
2929
void FFT_CUDA_BATCH<double>::setupFFT()
3030
{
31-
std::cout<<"the nx ,ny,nz,batch is: "
32-
<<this->nx<<" "<<this->ny<<" "<<this->nz<<" "<<this->batch<<std::endl;
3331
int rank = 3;
3432
int n[3] = {this->nx, this->ny, this->nz};
3533
const int size = this->nx* this->ny *this->nz;

source/module_basis/module_pw/test_gpu/pw_basis_k_batch.cpp

Lines changed: 98 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "module_base/module_device/device.h"
44
#include "module_base/vector3.h"
55
#include "module_basis/module_pw/pw_basis_k.h"
6+
#include "module_basis/module_pw/module_fft/fft_bundle.h"
67
#include "pw_test.h"
78
#include <complex>
89
#include <vector>
@@ -14,45 +15,132 @@ using namespace std;
1415
class PW_BASIS_K_BATCH_GPU_TEST : public ::testing::Test
1516
{
1617
public:
17-
const int batch = 10; // Number of batches
18+
const int batch = 1; // Number of batches
1819
const int npwk = 30; // Number of planewaves
1920
const int nxyz = 1000; // Size of the 3D grid
2021
std::vector<int> box_index; // Index mapping for 3D grid
2122
int* d_box_index=nullptr; // Device memory for box_index
22-
std::vector<std::complex<double>> rhog; // Input data for the test,
23+
std::vector<std::complex<double>> rhog; // rhoG(K space) data for the test,
2324
std::complex<double>* d_rhog = nullptr; // Device memory for rhoG data
24-
std::vector<std::complex<double>> rhor = nullptr; // Device memory for output rhoG data
25-
std::complex<double>* d_rhor = nullptr; // Device memory for output data
25+
std::complex<double>* d_rhog_batch = nullptr; // Device memory for rhoG output data
26+
std::vector<std::complex<double>> rhor; // Device memory for output rhoR(R space) data
27+
std::complex<double>* d_rhor = nullptr; // Device memory for rhoR data
28+
std::complex<double>* d_rhor_batch = nullptr; // Device memory for rhoR output batch data
29+
ModulePW::FFT_Bundle ft_gpu; // FFT bundle for 3D FFT operations on GPU
30+
ModulePW::FFT_Bundle ft_gpu_batch; // FFT bundle for 3D FFT operations on batch-GPU
2631
void SetUp() override
2732
{
2833
box_index.resize(npwk);
29-
rhog.resize(npwk);
34+
rhog.resize(npwk * batch);
35+
rhor.resize(nxyz * batch);
36+
3037
resize_memory_int_gpu_op()(d_box_index, npwk);
31-
resize_memory_complex_gpu_op()(d_rhog, npwk);
38+
resize_memory_complex_gpu_op()(d_rhog, npwk * batch);
39+
resize_memory_complex_gpu_op()(d_rhor, nxyz * batch);
40+
resize_memory_complex_gpu_op()(d_rhog_batch, npwk * batch);
41+
resize_memory_complex_gpu_op()(d_rhor_batch, nxyz * batch);
42+
3243
// Initialize the box_index and input with some values
3344
int idx = 0;
3445
std::generate_n(box_index.begin(), npwk, [&idx] { return idx * idx++; });
3546
idx =0;
36-
std::generate_n(rhog.begin(), npwk, [&idx] { return std::complex<double>(sqrt(idx), 1/(idx+1)); });
47+
int npwk = box_index.size();
48+
// Initialize rhog with some complex values,it generates a complex number
49+
// with real part as sqrt(idx) and imaginary part as 1/(idx+1),
50+
// thus in different batches the values of rhog will be different.
51+
std::generate_n(rhog.begin(), npwk * batch, [&idx,npwk]
52+
{
53+
idx ++;
54+
return std::complex<double>(std::sqrt(idx), 1.0/(idx+1));
55+
});
3756
synchronize_memory_int_h2d_op()(d_box_index, box_index.data(), npwk);
38-
synchronize_memory_complex_h2d_op()(d_rhog, rhog.data(), npwk);
57+
synchronize_memory_complex_h2d_op()(d_rhog, rhog.data(), npwk * batch);
58+
synchronize_memory_complex_h2d_op()(d_rhog_batch, rhog.data(), npwk * batch);
3959
// Initialize the box_index with some values
40-
41-
// resize_memory_int_gpu_op
60+
ft_gpu.setfft("gpu", "double");
61+
ft_gpu.initfft(10, 10, 10 , 1, 1, 1, 1, 1, 1);
62+
ft_gpu.setupFFT();
63+
ft_gpu_batch.setfft("gpu", "double");
64+
ft_gpu_batch.initfft(10, 10, 10 , 1, 1, 1, 1, 1, 1);
65+
ft_gpu_batch.setupFFT();
4266
}
4367
void TearDown() override
4468
{
4569
box_index.clear();
4670
rhog.clear();
71+
rhor.clear();
4772
delete_memory_int_gpu_op()(d_box_index);
4873
delete_memory_complex_gpu_op()(d_rhog);
74+
delete_memory_complex_gpu_op()(d_rhor);
75+
ft_gpu.clear();
76+
ft_gpu_batch.clear();
4977
}
5078
};
5179

5280
TEST_F(PW_BASIS_K_BATCH_GPU_TEST,convulution)
5381
{
82+
// STEP 1 set the 3D FFT box operation for CPU
5483
for (int i = 0; i < npwk; ++i)
5584
{
5685
EXPECT_EQ(box_index[i], i * i);
5786
}
87+
88+
// STEP 2 check the input rhog has been
89+
// correctly mapped to the 3D grid
90+
std::vector<std::complex<double>> compute_rhor(nxyz * batch);
91+
std::vector<std::complex<double>> compute_rhor_batch(nxyz * batch);
92+
for (int i = 0; i< batch; i++)
93+
{
94+
ModulePW::set_3d_fft_box_op<double,
95+
base_device::DEVICE_GPU>()
96+
(
97+
npwk,
98+
d_box_index,
99+
d_rhog + i * npwk,
100+
d_rhor + i * nxyz
101+
);
102+
synchronize_memory_complex_d2h_op()(compute_rhor.data()+i * nxyz, d_rhor + i *nxyz, nxyz);
103+
}
104+
ModulePW::set_3d_fft_box_op<double,
105+
base_device::DEVICE_GPU>()
106+
(
107+
npwk,
108+
nxyz,
109+
d_box_index,
110+
d_rhog_batch,
111+
d_rhor_batch,
112+
batch
113+
);
114+
115+
synchronize_memory_complex_d2h_op()(compute_rhor_batch.data(), d_rhor_batch,nxyz * batch);
116+
for (int i = 0; i < nxyz*batch ; ++i)
117+
{
118+
EXPECT_NEAR(compute_rhor[i].real(), compute_rhor_batch[i].real(), 1e-7);
119+
EXPECT_NEAR(compute_rhor[i].imag(), compute_rhor_batch[i].imag(), 1e-7);
120+
}
121+
122+
// STEP 3 perform the 3D FFT forward operation
123+
std::vector<std::complex<double>> compute_rhor1(nxyz * batch,0);
124+
std::vector<std::complex<double>> compute_rhor_batch1(nxyz * batch,0);
125+
126+
for (int i=0;i< batch; i++)
127+
{
128+
ft_gpu.fft3D_forward(d_rhor, d_rhor);
129+
synchronize_memory_complex_d2h_op()(compute_rhor1.data(), d_rhor , nxyz);
130+
}
131+
// ft_gpu.fft3D_backward(d_rhor, d_rhor);
132+
for (int i=0;i< batch; i++)
133+
{
134+
ft_gpu_batch.fft3D_forward(d_rhor_batch, d_rhor_batch);
135+
synchronize_memory_complex_d2h_op()(compute_rhor_batch1.data(), d_rhor_batch , nxyz);
136+
}
137+
// ft_gpu.fft3D_forward(d_rhor_batch, d_rhor_batch);
138+
// synchronize_memory_complex_d2h_op()(compute_rhor_batch.data(), d_rhor_batch,nxyz * batch);
139+
140+
for (int i = 0; i < nxyz *batch ; ++i)
141+
{
142+
EXPECT_NEAR(compute_rhor1[i].real(), compute_rhor_batch1[i].real(), 1e-4);
143+
// EXPECT_NEAR(compute_rhor[i].imag(), compute_rhor_batch[i].imag(), 1e-4);
144+
}
145+
58146
}

0 commit comments

Comments
 (0)