33#include " module_base/module_device/device.h"
44#include " module_base/vector3.h"
55#include " module_basis/module_pw/pw_basis.h"
6+ #include " module_basis/module_pw/pw_basis_k.h"
67
78#include < complex>
89#include < gtest/gtest.h>
@@ -17,7 +18,7 @@ struct TypePair
1718};
1819
1920template <typename TypePair>
20- class MixedTypeTest : public ::testing::Test
21+ class PW_BASIS_K_GPU_TEST : public ::testing::Test
2122{
2223 public:
2324 using T = typename TypePair::T;
@@ -26,24 +27,26 @@ class MixedTypeTest : public ::testing::Test
2627 complex <T>* d_rhog = nullptr ;
2728 complex <T>* d_rhogr = nullptr ;
2829 complex <T>* d_rhogout = nullptr ;
29- T * d_rhor = nullptr ;
30- complex <T>* tmp;
31- complex <T>* h_rhog;
32- complex <T>* h_rhogout;
33- T * h_rhor;
30+ complex <T> * d_rhor = nullptr ;
31+ complex <T>* tmp = nullptr ;
32+ complex <T>* h_rhog = nullptr ;
33+ complex <T>* h_rhogout = nullptr ;
34+ complex <T> * h_rhor = nullptr ;
3435 void init (ModulePW::PW_Basis& pwtest)
3536 {
36- cout << " dividemthd 1, gamma_only: off, check fft between T and complex" << endl;
37-
3837 ModuleBase::Matrix3 latvec (1 , 1 , 0 , 0 , 1 , 1 , 0 , 0 , 2 );
3938 T wfcecut;
4039 T lat0 = 2.2 ;
40+
4141 bool gamma_only = false ;
4242 wfcecut = 18 ;
4343 gamma_only = false ;
4444 int distribution_type = 1 ;
4545 bool xprime = false ;
46-
46+ const int nks = 1 ;
47+ ModuleBase::Vector3<double >* kvec_d;
48+ kvec_d = new ModuleBase::Vector3<double >[nks];
49+ kvec_d[0 ].set (0 , 0 , 0 );
4750 // init
4851 const int mypool = 0 ;
4952 const int key = 1 ;
@@ -52,7 +55,6 @@ class MixedTypeTest : public ::testing::Test
5255 MPI_Comm POOL_WORLD;
5356 MPI_Comm_split (MPI_COMM_WORLD, mypool, key, &POOL_WORLD);
5457 pwtest.initmpi (nproc_in_pool, rank_in_pool, POOL_WORLD);
55-
5658 pwtest.initgrids (lat0, latvec, wfcecut);
5759 pwtest.initparameters (gamma_only, wfcecut, distribution_type, xprime);
5860 pwtest.setuptransform ();
@@ -73,95 +75,104 @@ class MixedTypeTest : public ::testing::Test
7375 G = GT.Transpose ();
7476 GGT = G * GT;
7577 tmp = new complex <T>[nx * ny * nz];
76- if (rank_in_pool == 0 )
77- {
78- for (int ix = 0 ; ix < nx; ++ix)
78+ if (rank_in_pool == 0 )
7979 {
80- const T vx = ix - int (nx / 2 );
81- for (int iy = 0 ; iy < ny; ++iy)
80+ for (int ix = 0 ; ix < nx; ++ix)
8281 {
83- const int offset = (ix * ny + iy) * nz;
84- const T vy = iy - int (ny / 2 );
85- for (int iz = 0 ; iz < nz; ++iz)
82+ const T vx = ix - int (nx / 2 );
83+ for (int iy = 0 ; iy < ny; ++iy)
8684 {
87- tmp[offset + iz] = 0.0 ;
88- T vz = iz - int (nz / 2 );
89- ModuleBase::Vector3<double > v (vx, vy, vz);
90- T modulus = v * (GGT * v);
91- if (modulus <= ggecut)
85+ const int offset = (ix * ny + iy) * nz;
86+ const T vy = iy - int (ny / 2 );
87+ for (int iz = 0 ; iz < nz; ++iz)
9288 {
93- tmp[offset + iz] = 1.0 / (modulus + 1 );
94- if (vy > 0 )
89+ tmp[offset + iz] = 0.0 ;
90+ T vz = iz - int (nz / 2 );
91+ ModuleBase::Vector3<double > v (vx, vy, vz);
92+ T modulus = v * (GGT * v);
93+ if (modulus <= ggecut)
9594 {
96- tmp[offset + iz] += std::complex <T>(0 ,1.0 ) / (std::abs (static_cast <T>(v.x ) + 1 ) + 1 );
97- }
98- else if (vy < 0 )
99- {
100- tmp[offset + iz] -= std::complex <T>(0 ,1.0 ) / (std::abs (-static_cast <T>(v.x ) + 1 ) + 1 );
95+ tmp[offset + iz] = 1.0 / (modulus + 1 );
96+ if (vy > 0 )
97+ {
98+ tmp[offset + iz]
99+ += std::complex <T>(0 , 1.0 ) / (std::abs (static_cast <T>(v.x ) + 1 ) + 1 );
100+ }
101+ else if (vy < 0 )
102+ {
103+ tmp[offset + iz]
104+ -= std::complex <T>(0 , 1.0 ) / (std::abs (-static_cast <T>(v.x ) + 1 ) + 1 );
105+ }
101106 }
102107 }
103108 }
104109 }
105- }
106- if (typeid (T)==typeid (double ))
107- {
108- fftw_plan pp
109- = fftw_plan_dft_3d (nx, ny, nz, (fftw_complex*)tmp, (fftw_complex*)tmp, FFTW_BACKWARD, FFTW_ESTIMATE);
110- fftw_execute (pp);
111- fftw_destroy_plan (pp);
112- }else if (typeid (T)==typeid (float )){
113- fftwf_plan pp
114- = fftwf_plan_dft_3d (nx, ny, nz, (fftwf_complex*)tmp, (fftwf_complex*)tmp, FFTW_BACKWARD, FFTW_ESTIMATE);
115- fftwf_execute (pp);
116- fftwf_destroy_plan (pp);
117- }
118- ModuleBase::Vector3<T> delta_g (T (int (nx / 2 )) / nx,
119- T (int (ny / 2 )) / ny,
120- T (int (nz / 2 )) / nz);
121- for (int ixy = 0 ; ixy < nx * ny; ++ixy)
122- {
123- const int ix = ixy / ny;
124- const int iy = ixy % ny;
125- for (int iz = 0 ; iz < nz; ++iz)
110+ if (typeid (T) == typeid (double ))
126111 {
127- ModuleBase::Vector3<T> real_r (ix, iy, iz);
128- T phase_im = -delta_g * real_r;
129- complex <T> phase (0 , ModuleBase::TWO_PI * phase_im);
130- tmp[ixy * nz + iz] *= exp (phase);
112+ fftw_plan pp = fftw_plan_dft_3d (nx,
113+ ny,
114+ nz,
115+ (fftw_complex*)tmp,
116+ (fftw_complex*)tmp,
117+ FFTW_BACKWARD,
118+ FFTW_ESTIMATE);
119+ fftw_execute (pp);
120+ fftw_destroy_plan (pp);
121+ }
122+ else if (typeid (T) == typeid (float ))
123+ {
124+ fftwf_plan pp = fftwf_plan_dft_3d (nx,
125+ ny,
126+ nz,
127+ (fftwf_complex*)tmp,
128+ (fftwf_complex*)tmp,
129+ FFTW_BACKWARD,
130+ FFTW_ESTIMATE);
131+ fftwf_execute (pp);
132+ fftwf_destroy_plan (pp);
133+ }
134+ ModuleBase::Vector3<T> delta_g (T (int (nx / 2 )) / nx, T (int (ny / 2 )) / ny, T (int (nz / 2 )) / nz);
135+ for (int ixy = 0 ; ixy < nx * ny; ++ixy)
136+ {
137+ const int ix = ixy / ny;
138+ const int iy = ixy % ny;
139+ for (int iz = 0 ; iz < nz; ++iz)
140+ {
141+ ModuleBase::Vector3<T> real_r (ix, iy, iz);
142+ T phase_im = -delta_g * real_r;
143+ complex <T> phase (0 , ModuleBase::TWO_PI * phase_im);
144+ tmp[ixy * nz + iz] *= exp (phase);
145+ }
131146 }
132- }
133- }
134- h_rhog = new complex <T>[npw];
135- h_rhogout = new complex <T>[npw];
136-
137- cudaMalloc ((void **)&d_rhog, npw * sizeof (complex <T>));
138- cudaMalloc ((void **)&d_rhogr, npw * sizeof (complex <T>));
139- cudaMalloc ((void **)&d_rhogout, npw * sizeof (complex <T>));
140147
141- for (int ig = 0 ; ig < npw; ++ig)
142- {
143- h_rhog[ig] = 1.0 / (pwtest.gg [ig] + 1 );
144- if (pwtest.gdirect [ig].y > 0 )
145- {
146- h_rhog[ig] += ModuleBase::IMAG_UNIT / (std::abs (pwtest.gdirect [ig].x + 1 ) + 1 );
147- }
148- else if (pwtest.gdirect [ig].y < 0 )
149- {
150- h_rhog[ig] -= ModuleBase::IMAG_UNIT / (std::abs (-pwtest.gdirect [ig].x + 1 ) + 1 );
151- }
152- }
153- cudaMemcpy (d_rhog, h_rhog, npw * sizeof (complex <T>), cudaMemcpyHostToDevice);
148+ h_rhog = new complex <T>[npw];
149+ h_rhogout = new complex <T>[npw];
150+ for (int ig = 0 ; ig < npw; ++ig)
151+ {
152+ h_rhog[ig] = 1.0 / (pwtest.gg [ig] + 1 );
153+
154+ if (pwtest.gdirect [ig].y > 0 )
155+ {
156+ h_rhog[ig] += std::complex <float >(0 , 1.0 ) / (std::abs (float (pwtest.gdirect [ig].x ) + 1 ) + 1 );
157+ }
158+ else if (pwtest.gdirect [ig].y < 0 )
159+ {
160+ h_rhog[ig] -= std::complex <float >(0 , 1.0 ) / (std::abs (float (-pwtest.gdirect [ig].x ) + 1 ) + 1 );
161+ }
162+ }
154163
155- h_rhor = new T[nrxx];
164+ cudaMalloc ((void **)&d_rhog, npw * sizeof (complex <T>));
165+ cudaMalloc ((void **)&d_rhor, nrxx * sizeof (complex <T>));
166+ cudaMemcpy (d_rhog, h_rhog, npw * sizeof (complex <T>), cudaMemcpyHostToDevice);
156167
157- cudaMalloc ((void **)&d_rhor, nrxx * sizeof (T));
158- pwtest.recip_to_real <std::complex <T>, T, base_device::DEVICE_GPU>(d_rhog, d_rhor);
159- cudaMemcpy (h_rhor, d_rhor, nrxx * sizeof (T), cudaMemcpyDeviceToHost);
168+ h_rhor = new complex <T>[nrxx];
160169
161- pwtest.real_to_recip <T, std::complex <T>,base_device::DEVICE_GPU>(d_rhor, d_rhog);
162- cudaMemcpy (h_rhogout,d_rhog,npw * sizeof (complex <T>),cudaMemcpyDeviceToHost);
170+ pwtest.recip_to_real <std:: complex <T>, std::complex <T>,base_device::DEVICE_GPU>(d_rhog, d_rhor );
171+ cudaMemcpy (h_rhor, d_rhor, nrxx * sizeof (complex <T>), cudaMemcpyDeviceToHost);
163172
164-
173+ pwtest.real_to_recip <std::complex <T>, std::complex <T>,base_device::DEVICE_GPU>(d_rhor, d_rhog);
174+ cudaMemcpy (h_rhogout, d_rhog, npw * sizeof (complex <T>), cudaMemcpyDeviceToHost);
175+ }
165176 }
166177 ModulePW::PW_Basis* access_pw ()
167178 {
@@ -180,20 +191,18 @@ class MixedTypeTest : public ::testing::Test
180191 }
181192};
182193
183- // 类型参数列表(保持语义明确)
184- using MixedTypes = ::testing::Types<TypePair<float , base_device::DEVICE_GPU>, // float转double计算
185- TypePair<double , base_device::DEVICE_GPU> // double转float计算
186- >;
194+ using MixedTypes = ::testing::Types<TypePair<float , base_device::DEVICE_GPU>,
195+ TypePair<double , base_device::DEVICE_GPU> >;
187196
188- TYPED_TEST_CASE (MixedTypeTest , MixedTypes);
197+ TYPED_TEST_CASE (PW_BASIS_K_GPU_TEST , MixedTypes);
189198
190- TYPED_TEST (MixedTypeTest , Mixing)
199+ TYPED_TEST (PW_BASIS_K_GPU_TEST , Mixing)
191200{
192201 using T = typename TestFixture::T;
193202 using Device = typename TestFixture::Device;
194203 ModulePW::PW_Basis pwtest;
195204 pwtest.set_device (" gpu" );
196- pwtest.set_precision (" double " );
205+ pwtest.set_precision (" mixing " );
197206 pwtest.fft_bundle .setfft (" gpu" , " mixing" );
198207 this ->init (pwtest);
199208 int startiz = pwtest.startz_current ;
@@ -208,24 +217,23 @@ TYPED_TEST(MixedTypeTest, Mixing)
208217 const int startz = ixy * nplane;
209218 for (int iz = 0 ; iz < nplane; ++iz)
210219 {
211- EXPECT_NEAR (this ->tmp [offset + iz].real (),
212- this ->h_rhor [startz + iz], 1e-4 );
220+ EXPECT_NEAR (this ->tmp [offset + iz].real (), this ->h_rhor [startz + iz].real (), 1e-4 );
213221 }
214222 }
215- for (int ig = 0 ; ig < pwtest. npw ; ++ig)
223+ for (int ig = 0 ; ig < npw; ++ig)
216224 {
217225 EXPECT_NEAR (this ->h_rhog [ig].real (), this ->h_rhogout [ig].real (), 1e-4 );
218226 EXPECT_NEAR (this ->h_rhog [ig].imag (), this ->h_rhogout [ig].imag (), 1e-4 );
219227 }
220228}
221229
222- TYPED_TEST (MixedTypeTest , FloatDouble)
230+ TYPED_TEST (PW_BASIS_K_GPU_TEST , FloatDouble)
223231{
224232 using T = typename TestFixture::T;
225233 using Device = typename TestFixture::Device;
226234 ModulePW::PW_Basis pwtest;
227235 pwtest.set_device (" gpu" );
228- pwtest.set_precision (" double " );
236+ pwtest.set_precision (" mixing " );
229237 if (typeid (T) == typeid (float ))
230238 {
231239 pwtest.fft_bundle .setfft (" gpu" , " single" );
@@ -252,12 +260,11 @@ TYPED_TEST(MixedTypeTest, FloatDouble)
252260 const int startz = ixy * nplane;
253261 for (int iz = 0 ; iz < nplane; ++iz)
254262 {
255- EXPECT_NEAR (this ->tmp [offset + iz].real (),
256- this ->h_rhor [startz + iz], 1e-4 );
263+ EXPECT_NEAR (this ->tmp [offset + iz].real (), this ->h_rhor [startz + iz].real (), 1e-4 );
257264 }
258265 }
259266
260- for (int ig = 0 ; ig < pwtest. npw ; ++ig)
267+ for (int ig = 0 ; ig < npw; ++ig)
261268 {
262269 EXPECT_NEAR (this ->h_rhog [ig].real (), this ->h_rhogout [ig].real (), 1e-4 );
263270 EXPECT_NEAR (this ->h_rhog [ig].imag (), this ->h_rhogout [ig].imag (), 1e-4 );
0 commit comments