Skip to content

Commit a63a862

Browse files
committed
fix: fix a bug in initgrids
We can use fewer FFT grids. Scope: pw_init.cpp, pw_basis.h, pw_basis_k.cpp/h
1 parent 1fe4b5d commit a63a862

File tree

14 files changed

+75
-44
lines changed

14 files changed

+75
-44
lines changed

source/module_pw/pw_basis.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,23 @@ class PW_Basis
2929
void initgrids(
3030
double lat0_in, //unit length (unit in bohr)
3131
ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors (unit in lat0)
32-
double gridecut //unit in Ry, ecut to set up grids
32+
double gridecut, //unit in Ry, ecut to set up grids
33+
int poolnproc_in, // Number of processors in this pool
34+
int poolrank_in // Rank in this pool
3335
);
3436
//Init the grids for FFT
3537
void initgrids(
3638
double lat0_in,
3739
ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors
38-
int nx_in, int bigny_in, int nz_in
40+
int nx_in, int bigny_in, int nz_in,
41+
int poolnproc_in, // Number of processors in this pool
42+
int poolrank_in // Rank in this pool
3943
);
4044

4145
//Init some parameters
4246
void initparameters(
4347
bool gamma_only_in,
4448
double pwecut_in, //unit in Ry, ecut that decides the plane waves
45-
int poolnproc_in, // Number of processors in this pool
46-
int poolrank_in, // Rank in this pool
4749
int distribution_type_in
4850
);
4951

source/module_pw/pw_basis_k.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ void PW_Basis_K:: initparameters(
2424
double gk_ecut_in,
2525
int nks_in, //number of k points in this pool
2626
ModuleBase::Vector3<double> *kvec_d_in, // Direct coordinates of k points
27-
int poolnproc_in, // Number of processors in this pool
28-
int poolrank_in, // Rank in this pool
2927
int distribution_type_in
3028
)
3129
{
@@ -53,8 +51,6 @@ void PW_Basis_K:: initparameters(
5351
this->nxy = this->nx * this->ny;
5452
this->nxyz = this->nxy * this->nz;
5553

56-
this->poolnproc = poolnproc_in;
57-
this->poolrank = poolrank_in;
5854
this->distribution_type = distribution_type_in;
5955
return;
6056
}

source/module_pw/pw_basis_k.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ class PW_Basis_K : public PW_Basis
2222
double ecut_in,
2323
int nk_in, //number of k points in this pool
2424
ModuleBase::Vector3<double> *kvec_d, // Direct coordinates of k points
25-
int poolnproc_in, // Number of processors in this pool
26-
int poolrank_in, // Rank in this pool
2725
int distribution_type_in
2826
);
2927

source/module_pw/pw_init.cpp

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "./pw_basis.h"
22
#include "../module_base/constants.h"
3-
#include "../module_base/timer.h"
3+
#ifdef __MPI
4+
#include "mpi.h"
5+
#include "../src_parallel/parallel_global.h"
6+
#endif
47

58
namespace ModulePW
69
{
@@ -12,16 +15,19 @@ namespace ModulePW
1215
void PW_Basis:: initgrids(
1316
double lat0_in, //unit length (unit in bohr)
1417
ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors
15-
double gridecut
18+
double gridecut,
19+
int poolnproc_in,
20+
int poolrank_in
1621
)
1722
{
18-
// ModuleBase::timer::start();
19-
//init latice
23+
//init lattice
2024
this->lat0 = lat0_in;
2125
this->latvec = latvec_in;
2226
this->GT = latvec.Inverse();
2327
this->G = GT.Transpose();
2428
this->GGT = G * GT;
29+
this->poolnproc = poolnproc_in;
30+
this->poolrank = poolrank_in;
2531

2632

2733
//------------------------------------------------------------
@@ -35,18 +41,46 @@ void PW_Basis:: initgrids(
3541
lat.x = latvec.e11;
3642
lat.y = latvec.e12;
3743
lat.z = latvec.e13;
38-
ibox[0] = 2 * int(sqrt(gridecut) * sqrt(lat * lat)) + 1;
44+
ibox[0] = int(sqrt(gridecut) * sqrt(lat * lat)) + 1;
3945

4046
lat.x = latvec.e21;
4147
lat.y = latvec.e22;
4248
lat.z = latvec.e23;
43-
ibox[1] = 2 * int(sqrt(gridecut) * sqrt(lat * lat)) + 1;
49+
ibox[1] = int(sqrt(gridecut) * sqrt(lat * lat)) + 1;
4450

4551
lat.x = latvec.e31;
4652
lat.y = latvec.e32;
4753
lat.z = latvec.e33;
48-
ibox[2] = 2 * int(sqrt(gridecut) * sqrt(lat * lat)) + 1;
49-
//lat*lat=lat.x*lat.x+lat.y*lat.y+lat.z+lat.z
54+
ibox[2] = int(sqrt(gridecut) * sqrt(lat * lat)) + 1;
55+
56+
int n1,n2,n3;
57+
n1 = n2 = n3 = 0;
58+
for(int igz = -ibox[2]+this->poolrank; igz <= ibox[2]; igz += this->poolnproc)
59+
{
60+
for(int igy = -ibox[1]; igy <= ibox[1]; ++igy)
61+
{
62+
for(int igx = -ibox[0]; igx <= ibox[0]; ++igx)
63+
{
64+
ModuleBase::Vector3<double> f;
65+
f.x = igx;
66+
f.y = igy;
67+
f.z = igz;
68+
double modulus = f * (this->GGT * f);
69+
if(modulus <= gridecut)
70+
{
71+
if(n1 < abs(igx)) n1 = abs(igx);
72+
if(n2 < abs(igy)) n2 = abs(igy);
73+
if(n3 < abs(igz)) n3 = abs(igz);
74+
}
75+
}
76+
}
77+
}
78+
ibox[0] = 2*n1+1;
79+
ibox[1] = 2*n2+1;
80+
ibox[2] = 2*n3+1;
81+
#ifdef __MPI
82+
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , POOL_WORLD);
83+
#endif
5084

5185
// Find the minimal FFT box size that factors into the primes (2,3,5,7).
5286
for (int i = 0; i < 3; i++)
@@ -62,7 +96,6 @@ void PW_Basis:: initgrids(
6296
// increase ibox[i] by 1 until it is totally factorizable by (2,3,5,7)
6397
do
6498
{
65-
ibox[i] += 1;
6699
b = ibox[i];
67100

68101
//n2 = n3 = n5 = n7 = 0;
@@ -91,8 +124,10 @@ void PW_Basis:: initgrids(
91124
//if (b%7==0) { n7++; b /= 7; continue; }
92125
done_factoring = true;
93126
}
127+
ibox[i] += 1;
94128
}
95129
while (b != 1);
130+
ibox[i] -= 1;
96131
// b==1 means ibox[i] is (2,3,5,7) factorizable
97132
}
98133
this->nx = ibox[0];
@@ -113,7 +148,9 @@ void PW_Basis:: initgrids(
113148
void PW_Basis:: initgrids(
114149
double lat0_in,
115150
ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors
116-
int nx_in, int bigny_in, int nz_in
151+
int nx_in, int bigny_in, int nz_in,
152+
int poolnproc_in,
153+
int poolrank_in
117154
)
118155
{
119156
this->lat0 = lat0_in;
@@ -126,6 +163,8 @@ void PW_Basis:: initgrids(
126163
this->nz = nz_in;
127164
this->bignxy = this->nx * this->bigny;
128165
this->bignxyz = this->bignxy * this->nz;
166+
this->poolnproc = poolnproc_in;
167+
this->poolrank = poolrank_in;
129168

130169
return;
131170
}
@@ -135,8 +174,6 @@ void PW_Basis:: initgrids(
135174
void PW_Basis:: initparameters(
136175
bool gamma_only_in,
137176
double pwecut_in,
138-
int poolnproc_in,
139-
int poolrank_in,
140177
int distribution_type_in
141178
)
142179
{
@@ -150,8 +187,6 @@ void PW_Basis:: initparameters(
150187

151188
double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / this->lat0 / this->lat0;
152189
this->ggecut = pwecut_in / tpiba2;
153-
this->poolnproc = poolnproc_in;
154-
this->poolrank = poolrank_in;
155190
this->distribution_type = distribution_type_in;
156191
}
157192

source/module_pw/test/test1-1.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ TEST(PWTEST,test1_1)
2020

2121
ModulePW::PW_Basis pwtest;
2222

23-
pwtest.initgrids(lat0, latvec, ecut);
24-
pwtest.initparameters(gamma_only, ecut, nproc_in_pool, rank_in_pool, distribution_type);
23+
pwtest.initgrids(lat0, latvec, ecut,nproc_in_pool, rank_in_pool);
24+
pwtest.initparameters(gamma_only, ecut, distribution_type);
2525
pwtest.distribute_r();
2626
pwtest.distribute_g();
2727

source/module_pw/test/test1-2.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ TEST_F(PWTEST,test1_2)
3131
//--------------------------------------------------
3232

3333
//init
34-
pwtest.initgrids(lat0,latvec,wfcecut);
34+
pwtest.initgrids(lat0,latvec,wfcecut, nproc_in_pool, rank_in_pool);
3535
//pwtest.initgrids(lat0,latvec,5,7,7);
36-
pwtest.initparameters(gamma_only,wfcecut,nproc_in_pool,rank_in_pool,distribution_type);
36+
pwtest.initparameters(gamma_only,wfcecut,distribution_type);
3737
pwtest.setuptransform();
3838
pwtest.collect_local_pw();
3939

source/module_pw/test/test1-2f.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ TEST_F(PWTEST,test1_2f)
3131
//--------------------------------------------------
3232

3333
//init
34-
pwtest.initgrids(lat0,latvec,wfcecut);
34+
pwtest.initgrids(lat0,latvec,wfcecut, nproc_in_pool, rank_in_pool);
3535
//pwtest.initgrids(lat0,latvec,5,7,7);
36-
pwtest.initparameters(gamma_only,wfcecut,nproc_in_pool,rank_in_pool,distribution_type);
36+
pwtest.initparameters(gamma_only,wfcecut,distribution_type);
3737
pwtest.setuptransform();
3838
pwtest.collect_local_pw();
3939

source/module_pw/test/test1-3.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ TEST_F(PWTEST,test1_3)
3131
//--------------------------------------------------
3232

3333
//init
34-
pwtest.initgrids(lat0,latvec,1.5*wfcecut);
34+
pwtest.initgrids(lat0,latvec,1.5*wfcecut, nproc_in_pool, rank_in_pool);
3535
//pwtest.initgrids(lat0,latvec,5,7,7);
36-
pwtest.initparameters(gamma_only,wfcecut,nproc_in_pool,rank_in_pool,distribution_type);
36+
pwtest.initparameters(gamma_only,wfcecut,distribution_type);
3737
pwtest.setuptransform();
3838
pwtest.collect_local_pw();
3939

source/module_pw/test/test1-3f.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ TEST_F(PWTEST,test1_3f)
3131
//--------------------------------------------------
3232

3333
//init
34-
pwtest.initgrids(lat0,latvec,1.5*wfcecut);
34+
pwtest.initgrids(lat0,latvec,1.5*wfcecut, nproc_in_pool, rank_in_pool);
3535
//pwtest.initgrids(lat0,latvec,5,7,7);
36-
pwtest.initparameters(gamma_only,wfcecut,nproc_in_pool,rank_in_pool,distribution_type);
36+
pwtest.initparameters(gamma_only,wfcecut,distribution_type);
3737
pwtest.setuptransform();
3838
pwtest.collect_local_pw();
3939

source/module_pw/test/test1-4.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ TEST_F(PWTEST,test1_4)
3636
int distribution_type = 1;
3737
//--------------------------------------------------
3838
//init
39-
pwtest.initgrids(lat0,latvec,wfcecut);
39+
pwtest.initgrids(lat0,latvec,wfcecut, nproc_in_pool, rank_in_pool);
4040
//pwtest.initgrids(lat0,latvec,5,7,7);
41-
pwtest.initparameters(gamma_only,wfcecut,nks,kvec_d,nproc_in_pool,rank_in_pool,distribution_type);
41+
pwtest.initparameters(gamma_only,wfcecut,nks,kvec_d,distribution_type);
4242
pwtest.setuptransform();
4343
pwtest.collect_local_pw();
4444

0 commit comments

Comments
 (0)