Skip to content

Commit f433ed1

Browse files
committed
Merge planewave branch
2 parents 10c2079 + 477a2a5 commit f433ed1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2046
-2305
lines changed

source/module_pw/fft.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,16 @@ class FFT
2727

2828
FFT();
2929
~FFT();
30+
31+
// init parameters of fft
3032
void initfft(int nx_in, int bigny_in, int nz_in, int liy_in, int riy_in, int ns_in, int nplane_in,
3133
int nproc_in, bool gamma_only_in, bool mpifft_in = false);
32-
void setupFFT();
33-
void cleanFFT();
34+
35+
//init fftw_plans
36+
void setupFFT();
37+
38+
//destroy fftw_plans
39+
void cleanFFT();
3440

3541
void fftzfor(std::complex<double>* & in, std::complex<double>* & out);
3642
void fftzbac(std::complex<double>* & in, std::complex<double>* & out);
@@ -50,10 +56,12 @@ class FFT
5056
#endif
5157

5258
public:
53-
void initplan();
59+
//init fftw_plans
60+
void initplan();
5461
void initplan_mpi();
5562
#ifdef __MIX_PRECISION
56-
void initplanf();
63+
//init fftwf_plans
64+
void initplanf();
5765
void initplanf_mpi();
5866
#endif
5967

source/module_pw/pw_basis.cpp

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ PW_Basis:: ~PW_Basis()
5252
if(startr != NULL) delete[] startr;
5353
}
5454

55+
///
56+
/// distribute plane wave basis and real-space grids to different processors
57+
/// set up maps for fft and create arrays for MPI_Alltoall
58+
/// set up ffts
59+
///
5560
void PW_Basis::setuptransform()
5661
{
5762
this->distribute_r();
@@ -96,14 +101,11 @@ void PW_Basis::getstartgr()
96101
return;
97102
}
98103

99-
//
100-
// Collect planewaves on current core, and construct gg, gdirect, gcar according to ig2isz and is2ixy.
101-
// is2ixy contains the x-coordinate and y-coordinate of sticks on current core.
102-
// ig2isz contains the z-coordinate of planewaves on current core.
103-
// We will scan the sticks on current core and find the planewaves on them, then store the information into corresponding arrays.
104-
// known: ig2isz, is2ixy
105-
// output: gg, gdirect, gcar
106-
//
104+
///
105+
/// Collect planewaves on current core, and construct gg, gdirect, gcar according to ig2isz and is2ixy.
106+
/// known: ig2isz, is2ixy
107+
/// output: gg, gdirect, gcar
108+
///
107109
void PW_Basis::collect_local_pw()
108110
{
109111
if(gg != NULL) delete[] gg;
@@ -114,34 +116,24 @@ void PW_Basis::collect_local_pw()
114116
this->gcar = new ModuleBase::Vector3<double>[this->npw];
115117

116118
ModuleBase::Vector3<double> f;
117-
int pw_filled = 0; // how many current core's planewaves have been found.
118-
for (int is = 0; is < this->nst; ++is)
119+
for(int ig = 0 ; ig < this-> npw ; ++ig)
119120
{
120-
int ix = this->is2ixy[is] / this->ny;
121-
int iy = this->is2ixy[is] % this->ny;
121+
int isz = this->ig2isz[ig];
122+
int iz = isz % this->nz;
123+
int is = isz / this->nz;
124+
int ixy = this->is2ixy[is];
125+
int ix = ixy / this->ny;
126+
int iy = ixy % this->ny;
122127
if (ix >= int(this->nx/2) + 1) ix -= this->nx;
123128
if (iy >= int(this->bigny/2) + 1) iy -= this->bigny;
124-
for (int ig = pw_filled; ig < this->npw; ++ig)
125-
{
126-
if (this->ig2isz[ig] < (is + 1) * this->nz) // meaning this pw belongs to is^th sticks.
127-
{
128-
int iz = this->ig2isz[ig] % this->nz;
129-
if (iz >= int(this->nz/2) + 1) iz -= this->nz;
130-
f.x = ix;
131-
f.y = iy;
132-
f.z = iz;
133-
this->gg[pw_filled] = f * (this->GGT * f);
134-
this->gdirect[pw_filled] = f;
135-
this->gcar[pw_filled] = f * this->G;
136-
pw_filled++;
137-
}
138-
else
139-
{
140-
break;
141-
}
142-
}
129+
if (iz >= int(this->nz/2) + 1) iz -= this->nz;
130+
f.x = ix;
131+
f.y = iy;
132+
f.z = iz;
133+
this->gg[ig] = f * (this->GGT * f);
134+
this->gdirect[ig] = f;
135+
this->gcar[ig] = f * this->G;
143136
}
144-
assert(pw_filled == this->npw);
145137
return;
146138
}
147139

source/module_pw/pw_basis.h

Lines changed: 53 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
namespace ModulePW
1111
{
1212

13-
//
14-
//A class which can convert a function of "r" to the corresponding linear
15-
// superposition of plane waves (real space to reciprocal space)
16-
// or convert a linear superposition of plane waves to the function
17-
// of "r" (reciprocal to real).
18-
//plane waves: <r|g>=1/sqrt(V) * exp(igr)
19-
// f(r) = 1/sqrt(V) * \sum_g{c(g)*exp(igr)}
20-
//
13+
///
14+
/// A class which can convert a function of "r" to the corresponding linear
15+
/// superposition of plane waves (real space to reciprocal space)
16+
/// or convert a linear superposition of plane waves to the function
17+
/// of "r" (reciprocal to real).
18+
/// plane waves: <r|g>=1/sqrt(V) * exp(igr)
19+
/// f(r) = 1/sqrt(V) * \sum_g{c(g)*exp(igr)}
20+
///
2121
class PW_Basis
2222
{
2323

@@ -29,21 +29,23 @@ class PW_Basis
2929
void initgrids(
3030
double lat0_in, //unit length (unit in bohr)
3131
ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors (unit in lat0)
32-
double gridecut //unit in Ry, ecut to set up grids
32+
double gridecut, //unit in Ry, ecut to set up grids
33+
int poolnproc_in, // Number of processors in this pool
34+
int poolrank_in // Rank in this pool
3335
);
3436
//Init the grids for FFT
3537
void initgrids(
3638
double lat0_in,
3739
ModuleBase::Matrix3 latvec_in, // Unitcell lattice vectors
38-
int nx_in, int bigny_in, int nz_in
40+
int nx_in, int bigny_in, int nz_in,
41+
int poolnproc_in, // Number of processors in this pool
42+
int poolrank_in // Rank in this pool
3943
);
4044

4145
//Init some parameters
4246
void initparameters(
4347
bool gamma_only_in,
4448
double pwecut_in, //unit in Ry, ecut to decides plane waves
45-
int poolnproc_in, // Number of processors in this pool
46-
int poolrank_in, // Rank in this pool
4749
int distribution_type_in
4850
);
4951

@@ -53,15 +55,15 @@ class PW_Basis
5355
public:
5456
//reciprocal-space
5557
// only on first proc.
56-
int *startnsz_per; // startnsz_per[ip]: starting is * nz stick in the ip^th proc.
57-
int *nstnz_per; // nz * nst(number of sticks) on each core.
58+
int *startnsz_per;//useless // startnsz_per[ip]: starting is * nz stick in the ip^th proc.
59+
int *nstnz_per;//useless // nz * nst(number of sticks) on each core.
5860
int *nst_per;// nst on each core
5961
// on all proc.
6062
int *ig2isz; // map ig to (is, iz).
6163
int *istot2bigixy; // istot2bigixy[is]: iy + ix * bigny of is^th stick among all sticks.
62-
int *ixy2istot; // ixy2istot[ix + iy * nx]: is of stick on (ix, iy) among all sticks.
63-
int *is2ixy; // is2ixy[is]: ix + iy * bignx of is^th stick among sticks on current proc.
64-
int *ixy2ip; // ixy2ip[ix + iy * nx]: ip of proc which contains stick on (ix, iy).
64+
int *ixy2istot; //useless // ixy2istot[iy + ix * ny]: is of stick on (ix, iy) among all sticks.
65+
int *is2ixy; // is2ixy[is]: iy + ix * bigny of is^th stick among sticks on current proc.
66+
int *ixy2ip; // useless// ixy2ip[iy + ix * ny]: ip of proc which contains stick on (ix, iy).
6567
int nst; //num. of sticks in current proc.
6668
int nstnz; // nst * nz
6769
int nstot; //num. of sticks in total.
@@ -91,13 +93,12 @@ class PW_Basis
9193
//distribute real-space grids to different processors
9294
void distribute_r();
9395

96+
//prepare for MPI_Alltoall
9497
void getstartgr();
9598

96-
//distribute plane waves to different processors
97-
void distribution_method1(); // x varies fast
98-
void distribution_method2(); // sticks sorted according to ixy
99-
// void distribution_method3(); // y varies fast
10099

100+
101+
//prepare for transforms between real and reciprocal spaces
101102
void collect_local_pw();
102103

103104
// void collect_tot_pw(
@@ -118,64 +119,56 @@ class PW_Basis
118119
int distribution_type;
119120
int poolnproc;
120121
int poolrank;
121-
122+
//distribute plane waves to different processors
123+
124+
//method 1: first consider number of plane waves
125+
void distribution_method1();
126+
// Distribute sticks to cores in method 1.
127+
void divide_sticks_1(
128+
int* st_i, // x or x + nx (if x < 0) of stick.
129+
int* st_j, // y or y + ny (if y < 0) of stick.
130+
int* st_length, // the stick on (ix, iy) consists of st_length[ix*ny+iy] planewaves.
131+
int* npw_per // number of planewaves on each core.
132+
);
122133

123-
// for both distributeg_method1 and distributeg_method2
134+
//method 2: first consider number of sticks
135+
void distribution_method2();
136+
// Distribute sticks to cores in method 2.
137+
void divide_sticks_2();
138+
139+
//Count the total number of planewaves (tot_npw) and sticks (this->nstot) (in distributeg method1 and method2)
124140
void count_pw_st(
125141
int &tot_npw, // total number of planewaves.
126142
int* st_length2D, // the number of planewaves that belong to the stick located on (x, y).
127143
int* st_bottom2D // the z-coordinate of the bottom of stick on (x, y).
128144
);
145+
146+
//get ig2isz and is2ixy
129147
void get_ig2isz_is2ixy(
130148
int* st_bottom, // minimum z of stick, stored in 1d array with tot_nst elements.
131149
int* st_length // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
132150
);
133-
// for distributeg_method1
151+
152+
//Collect the x, y indexs, length of the sticks (in distributeg method1)
134153
void collect_st(
135154
int* st_length2D, // the number of planewaves that belong to the stick located on (x, y), stored in 2d x-y plane.
136155
int* st_bottom2D, // the z-coordinate of the bottom of stick on (x, y), stored in 2d x-y plane.
137156
int* st_i, // x or x + nx (if x < 0) of stick.
138157
int* st_j, // y or y + ny (if y < 0) of stick.
139158
int* st_length // number of planewaves in stick, stored in 1d array with tot_nst elements.
140159
);
141-
void divide_sticks(
142-
int* st_i, // x or x + nx (if x < 0) of stick.
143-
int* st_j, // y or y + ny (if y < 0) of stick.
144-
int* st_length, // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
145-
int* npw_per // number of planewaves on each core.
146-
);
160+
161+
//get istot2bigixy
147162
void get_istot2bigixy(
148163
int* st_i, // x or x + nx (if x < 0) of stick.
149164
int* st_j // y or y + ny (if y < 0) of stick.
150165
);
151-
// for distributeg_method2
152-
void divide_sticks2();
166+
167+
//Create the maps from ixy to (in method 2)
153168
void create_maps(
154169
int* st_length2D, // the number of planewaves that belong to the stick located on (x, y), stored in 2d x-y plane.
155170
int* npw_per // number of planewaves on each core.
156171
);
157-
// for distributeg_method3
158-
// void divide_sticks2(
159-
// const int tot_npw, // total number of planewaves.
160-
// int* st_i, // x or x + nx (if x < 0) of stick.
161-
// int* st_j, // y or y + ny (if y < 0) of stick.
162-
// int* st_length, // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
163-
// int* npw_per, // number of planewaves on each core.
164-
// int* nst_per, // number of sticks on each core.
165-
// int* is2ip // ip of core containing is^th stick, map is to ip.
166-
// );
167-
// void get_istot2ixy2(
168-
// int* st_i, // x or x + nx (if x < 0) of stick.
169-
// int* st_j, // y or y + ny (if y < 0) of stick.
170-
// int* is2ip // ip of core containing is^th stick, map is to ip.
171-
// );
172-
// void get_ig2isz_is2ixy2(
173-
// int* st_i, // x or x + nx (if x < 0) of stick.
174-
// int* st_j, // y or y + ny (if y < 0) of stick.
175-
// int* st_bottom, // minimum z of stick, stored in 1d array with tot_nst elements.
176-
// int* st_length, // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
177-
// int* is2ip // ip of core containing is^th stick, map is to ip.
178-
// );
179172

180173
//===============================================
181174
// FFT
@@ -199,16 +192,14 @@ class PW_Basis
199192
void recip2real(std::complex<float> * in, float *out); //in:(nz, ns) ; out(nplane,nx*ny)
200193
void recip2real(std::complex<float> * in, std::complex<float> * out); //in:(nz, ns) ; out(nplane,nx*ny)
201194
#endif
195+
//gather planes and scatter sticks of all processors
202196
template<typename T>
203-
void gatherp_scatters(std::complex<T> *in, std::complex<T> *out); //gather planes and scatter sticks of all processors
204-
template<typename T>
205-
void gathers_scatterp(std::complex<T> *in, std::complex<T> *out); //gather sticks of and scatter planes of all processors
206-
// void gathers_scatterp2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
207-
// void gatherp_scatters2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
208-
// void gatherp_scatters_gamma(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors, used when gamma_only
209-
// void gathers_scatterp_gamma(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors, used when gamma only
197+
void gatherp_scatters(std::complex<T> *in, std::complex<T> *out);
210198

199+
//gather sticks of and scatter planes of all processors
200+
template<typename T>
201+
void gathers_scatterp(std::complex<T> *in, std::complex<T> *out);
211202
};
212203

213204
}
214-
#endif //PlaneWave class
205+
#endif //PlaneWave

source/module_pw/pw_basis_k.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ void PW_Basis_K:: initparameters(
2424
double gk_ecut_in,
2525
int nks_in, //number of k points in this pool
2626
ModuleBase::Vector3<double> *kvec_d_in, // Direct coordinates of k points
27-
int poolnproc_in, // Number of processors in this pool
28-
int poolrank_in, // Rank in this pool
2927
int distribution_type_in
3028
)
3129
{
@@ -53,8 +51,6 @@ void PW_Basis_K:: initparameters(
5351
this->nxy = this->nx * this->ny;
5452
this->nxyz = this->nxy * this->nz;
5553

56-
this->poolnproc = poolnproc_in;
57-
this->poolrank = poolrank_in;
5854
this->distribution_type = distribution_type_in;
5955
return;
6056
}
@@ -103,6 +99,12 @@ void PW_Basis_K::setupIndGk()
10399

104100
return;
105101
}
102+
103+
///
104+
/// distribute plane wave basis and real-space grids to different processors
105+
/// set up maps for fft and create arrays for MPI_Alltoall
106+
/// set up ffts
107+
///
106108
void PW_Basis_K::setuptransform()
107109
{
108110
this->distribute_r();

0 commit comments

Comments
 (0)