Skip to content

Commit e0258ad

Browse files
committed
add gtest
1 parent 1249fbe commit e0258ad

File tree

4 files changed

+52
-49
lines changed

4 files changed

+52
-49
lines changed

source/module_basis/module_pw/kernels/cuda/pw_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ __global__ void set_3d_fft_box_batch(
4141
thrust::complex<FPTYPE>* batch_out = out + batch_idx * nxyz;
4242

4343
const int box_idx = box_index[element_idx];
44-
44+
printf("the batch_idx is %d, the element_idx is %d, the box_idx is %d\n", batch_idx, element_idx, box_idx);
4545
const thrust::complex<FPTYPE> input_val = batch_in[element_idx];
4646
batch_out[box_idx] = input_val;
4747
}

source/module_basis/module_pw/pw_transform_k.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,13 +491,14 @@ void PW_Basis_K::real2recip_gpu(const std::complex<FPTYPE>* in,
491491

492492
const int startig = ik * this->npwk_max;
493493
const int npw_k = this->npwk[ik];
494+
std::cout << "real2recip_gpu: npw_k = " << npw_k << ", nxyz = " << this->nxyz << std::endl;
494495
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
495496
this->nxyz,
496497
add,
497498
factor,
498499
this->ig2ixyz_k + startig,
499500
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
500-
out);
501+
out,1);
501502
ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
502503
}
503504
template <typename FPTYPE>
@@ -518,18 +519,19 @@ void PW_Basis_K::recip2real_gpu(const std::complex<FPTYPE>* in,
518519

519520
const int startig = ik * this->npwk_max;
520521
const int npw_k = this->npwk[ik];
521-
522522
set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
523+
nxyz,
523524
this->ig2ixyz_k + startig,
524525
in,
525-
this->fft_bundle.get_auxr_3d_data<FPTYPE>());
526+
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
527+
1);
526528
this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>());
527529

528530
set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
529531
add,
530532
factor,
531533
this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
532-
out);
534+
out,1);
533535

534536
ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
535537
}

source/module_basis/module_pw/test_gpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ if (USE_CUDA)
33
AddTest(
44
TARGET pw_test_gpu
55
LIBS parameter ${math_libs} base planewave device FFTW3::FFTW3_FLOAT
6-
SOURCES pw_test.cpp pw_basis_C2R.cpp pw_basis_C2C.cpp pw_basis_k_C2C.cpp
6+
SOURCES pw_test.cpp pw_basis_k_C2C.cpp
77
)
88
endif()
99

source/module_basis/module_pw/test_gpu/pw_basis_k_C2C.cpp

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -196,41 +196,41 @@ class PW_BASIS_K_GPU_TEST : public ::testing::Test
196196
}
197197
};
198198

199-
using MixedTypes = ::testing::Types<TypePair<float, base_device::DEVICE_GPU>,
199+
using MixedTypes = ::testing::Types<
200200
TypePair<double, base_device::DEVICE_GPU> >;
201201

202202
TYPED_TEST_CASE(PW_BASIS_K_GPU_TEST, MixedTypes);
203203

204-
TYPED_TEST(PW_BASIS_K_GPU_TEST, Mixing)
205-
{
206-
using T = typename TestFixture::T;
207-
using Device = typename TestFixture::Device;
208-
ModulePW::PW_Basis_K pwtest;
209-
pwtest.set_device("gpu");
210-
pwtest.set_precision("mixing");
211-
pwtest.fft_bundle.setfft("gpu", "mixing");
212-
this->init(pwtest);
213-
int startiz = pwtest.startz_current;
214-
const int nx = pwtest.nx;
215-
const int ny = pwtest.ny;
216-
const int nz = pwtest.nz;
217-
const int nplane = pwtest.nplane;
218-
const int npwk = pwtest.npwk[0];
219-
for (int ixy = 0; ixy < nx * ny; ++ixy)
220-
{
221-
const int offset = ixy * nz + startiz;
222-
const int startz = ixy * nplane;
223-
for (int iz = 0; iz < nplane; ++iz)
224-
{
225-
EXPECT_NEAR(this->tmp[offset + iz].real(), this->h_rhor[startz + iz].real(), 1e-4);
226-
}
227-
}
228-
for (int ig = 0; ig < npwk; ++ig)
229-
{
230-
EXPECT_NEAR(this->h_rhog[ig].real(), this->h_rhogout[ig].real(), 1e-4);
231-
EXPECT_NEAR(this->h_rhog[ig].imag(), this->h_rhogout[ig].imag(), 1e-4);
232-
}
233-
}
204+
// TYPED_TEST(PW_BASIS_K_GPU_TEST, Mixing)
205+
// {
206+
// using T = typename TestFixture::T;
207+
// using Device = typename TestFixture::Device;
208+
// ModulePW::PW_Basis_K pwtest;
209+
// pwtest.set_device("gpu");
210+
// pwtest.set_precision("mixing");
211+
// pwtest.fft_bundle.setfft("gpu", "mixing");
212+
// this->init(pwtest);
213+
// int startiz = pwtest.startz_current;
214+
// const int nx = pwtest.nx;
215+
// const int ny = pwtest.ny;
216+
// const int nz = pwtest.nz;
217+
// const int nplane = pwtest.nplane;
218+
// const int npwk = pwtest.npwk[0];
219+
// for (int ixy = 0; ixy < nx * ny; ++ixy)
220+
// {
221+
// const int offset = ixy * nz + startiz;
222+
// const int startz = ixy * nplane;
223+
// for (int iz = 0; iz < nplane; ++iz)
224+
// {
225+
// EXPECT_NEAR(this->tmp[offset + iz].real(), this->h_rhor[startz + iz].real(), 1e-4);
226+
// }
227+
// }
228+
// for (int ig = 0; ig < npwk; ++ig)
229+
// {
230+
// EXPECT_NEAR(this->h_rhog[ig].real(), this->h_rhogout[ig].real(), 1e-4);
231+
// EXPECT_NEAR(this->h_rhog[ig].imag(), this->h_rhogout[ig].imag(), 1e-4);
232+
// }
233+
// }
234234

235235
TYPED_TEST(PW_BASIS_K_GPU_TEST, FloatDouble)
236236
{
@@ -239,19 +239,20 @@ TYPED_TEST(PW_BASIS_K_GPU_TEST, FloatDouble)
239239
ModulePW::PW_Basis_K pwtest;
240240
pwtest.set_device("gpu");
241241
pwtest.set_precision("mixing");
242-
if (typeid(T) == typeid(float))
243-
{
244-
pwtest.fft_bundle.setfft("gpu", "single");
245-
}
246-
else if (typeid(T) == typeid(double))
247-
{
242+
// if (typeid(T) == typeid(float))
243+
// {
244+
// pwtest.fft_bundle.setfft("gpu", "single");
245+
// }
246+
// if (typeid(T) == typeid(double))
247+
// {
248+
std::cout << "Using double precision" << std::endl;
248249
pwtest.fft_bundle.setfft("gpu", "double");
249-
}
250-
else
251-
{
252-
cout << "Error: Unsupported type" << endl;
253-
return;
254-
}
250+
// }
251+
// else
252+
// {
253+
// cout << "Error: Unsupported type" << endl;
254+
// return;
255+
// }
255256
this->init(pwtest);
256257
int startiz = pwtest.startz_current;
257258
const int nx = pwtest.nx;

0 commit comments

Comments
 (0)