11#include " ./pw_basis.h"
22#include " ../module_base/constants.h"
3- #include " ../module_base/timer.h"
3+ #ifdef __MPI
4+ #include " mpi.h"
5+ #include " ../src_parallel/parallel_global.h"
6+ #endif
47
58namespace ModulePW
69{
@@ -12,16 +15,19 @@ namespace ModulePW
1215void PW_Basis:: initgrids(
1316 double lat0_in, // unit length (unit in bohr)
1417 ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors
15- double gridecut
18+ double gridecut,
19+ int poolnproc_in,
20+ int poolrank_in
1621)
1722{
18- // ModuleBase::timer::start();
19- // init latice
23+ // init lattice
2024 this ->lat0 = lat0_in;
2125 this ->latvec = latvec_in;
2226 this ->GT = latvec.Inverse ();
2327 this ->G = GT.Transpose ();
2428 this ->GGT = G * GT;
29+ this ->poolnproc = poolnproc_in;
30+ this ->poolrank = poolrank_in;
2531
2632
2733 // ------------------------------------------------------------
@@ -35,18 +41,46 @@ void PW_Basis:: initgrids(
3541 lat.x = latvec.e11 ;
3642 lat.y = latvec.e12 ;
3743 lat.z = latvec.e13 ;
38- ibox[0 ] = 2 * int (sqrt (gridecut) * sqrt (lat * lat)) + 1 ;
44+ ibox[0 ] = int (sqrt (gridecut) * sqrt (lat * lat)) + 1 ;
3945
4046 lat.x = latvec.e21 ;
4147 lat.y = latvec.e22 ;
4248 lat.z = latvec.e23 ;
43- ibox[1 ] = 2 * int (sqrt (gridecut) * sqrt (lat * lat)) + 1 ;
49+ ibox[1 ] = int (sqrt (gridecut) * sqrt (lat * lat)) + 1 ;
4450
4551 lat.x = latvec.e31 ;
4652 lat.y = latvec.e32 ;
4753 lat.z = latvec.e33 ;
48- ibox[2 ] = 2 * int (sqrt (gridecut) * sqrt (lat * lat)) + 1 ;
49- // lat*lat=lat.x*lat.x+lat.y*lat.y+lat.z+lat.z
54+ ibox[2 ] = int (sqrt (gridecut) * sqrt (lat * lat)) + 1 ;
55+
56+ int n1,n2,n3;
57+ n1 = n2 = n3 = 0 ;
58+ for (int igz = -ibox[2 ]+this ->poolrank ; igz <= ibox[2 ]; igz += this ->poolnproc )
59+ {
60+ for (int igy = -ibox[1 ]; igy <= ibox[1 ]; ++igy)
61+ {
62+ for (int igx = -ibox[0 ]; igx <= ibox[0 ]; ++igx)
63+ {
64+ ModuleBase::Vector3<double > f;
65+ f.x = igx;
66+ f.y = igy;
67+ f.z = igz;
68+ double modulus = f * (this ->GGT * f);
69+ if (modulus <= gridecut)
70+ {
71+ if (n1 < abs (igx)) n1 = abs (igx);
72+ if (n2 < abs (igy)) n2 = abs (igy);
73+ if (n3 < abs (igz)) n3 = abs (igz);
74+ }
75+ }
76+ }
77+ }
78+ ibox[0 ] = 2 *n1+1 ;
79+ ibox[1 ] = 2 *n2+1 ;
80+ ibox[2 ] = 2 *n3+1 ;
81+ #ifdef __MPI
82+ MPI_Allreduce (MPI_IN_PLACE, ibox, 3 , MPI_INT, MPI_MAX , POOL_WORLD);
83+ #endif
5084
5185 // Find the minimal FFT box size the factors into the primes (2,3,5,7).
5286 for (int i = 0 ; i < 3 ; i++)
@@ -62,7 +96,6 @@ void PW_Basis:: initgrids(
6296 // increase ibox[i] by 1 until it is totally factorizable by (2,3,5,7)
6397 do
6498 {
65- ibox[i] += 1 ;
6699 b = ibox[i];
67100
68101 // n2 = n3 = n5 = n7 = 0;
@@ -91,8 +124,10 @@ void PW_Basis:: initgrids(
91124 // if (b%7==0) { n7++; b /= 7; continue; }
92125 done_factoring = true ;
93126 }
127+ ibox[i] += 1 ;
94128 }
95129 while (b != 1 );
130+ ibox[i] -= 1 ;
96131 // b==1 means fftbox[i] is (2,3,5,7) factorizable
97132 }
98133 this ->nx = ibox[0 ];
@@ -113,7 +148,9 @@ void PW_Basis:: initgrids(
113148void PW_Basis:: initgrids(
114149 double lat0_in,
115150 ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors
116- int nx_in, int bigny_in, int nz_in
151+ int nx_in, int bigny_in, int nz_in,
152+ int poolnproc_in,
153+ int poolrank_in
117154)
118155{
119156 this ->lat0 = lat0_in;
@@ -126,6 +163,8 @@ void PW_Basis:: initgrids(
126163 this ->nz = nz_in;
127164 this ->bignxy = this ->nx * this ->bigny ;
128165 this ->bignxyz = this ->bignxy * this ->nz ;
166+ this ->poolnproc = poolnproc_in;
167+ this ->poolrank = poolrank_in;
129168
130169 return ;
131170}
@@ -135,8 +174,6 @@ void PW_Basis:: initgrids(
135174void PW_Basis:: initparameters(
136175 bool gamma_only_in,
137176 double pwecut_in,
138- int poolnproc_in,
139- int poolrank_in,
140177 int distribution_type_in
141178)
142179{
@@ -150,8 +187,6 @@ void PW_Basis:: initparameters(
150187
151188 double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / this ->lat0 / this ->lat0 ;
152189 this ->ggecut = pwecut_in / tpiba2;
153- this ->poolnproc = poolnproc_in;
154- this ->poolrank = poolrank_in;
155190 this ->distribution_type = distribution_type_in;
156191}
157192
0 commit comments