Skip to content

Commit b8974ec

Browse files
committed
Refactor ModuleIO::read_cube_core()
1 parent 8dabb3f commit b8974ec

File tree

2 files changed

+76
-100
lines changed

2 files changed

+76
-100
lines changed

source/module_io/cube_io.h

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,6 @@ extern bool read_cube(
2828
int& prenspin,
2929
const bool warning_flag = true);
3030

31-
extern void read_cube_core(
32-
std::ifstream &ifs,
33-
#ifdef __MPI
34-
const Parallel_Grid*const Pgrid,
35-
#endif
36-
const int my_rank,
37-
const std::string esolver_type,
38-
const int rank_in_stogroup,
39-
#ifdef __MPI
40-
#else
41-
const int is,
42-
std::ofstream& ofs_running,
43-
#endif
44-
double*const data,
45-
const int nx,
46-
const int ny,
47-
const int nz,
48-
const int nx_read,
49-
const int ny_read,
50-
const int nz_read);
51-
5231
extern void write_cube(
5332
#ifdef __MPI
5433
const int bz,
@@ -69,6 +48,31 @@ extern void write_cube(
6948
const int precision = 11,
7049
const int out_fermi = 1); // mohan add 2007-10-17
7150

51+
52+
extern void read_cube_core_match(
53+
std::ifstream &ifs,
54+
#ifdef __MPI
55+
const Parallel_Grid*const Pgrid,
56+
const bool flag_read_rank,
57+
#endif
58+
double*const data,
59+
const int nxy,
60+
const int nz);
61+
62+
extern void read_cube_core_mismatch(
63+
std::ifstream &ifs,
64+
#ifdef __MPI
65+
const Parallel_Grid*const Pgrid,
66+
const bool flag_read_rank,
67+
#endif
68+
double*const data,
69+
const int nx,
70+
const int ny,
71+
const int nz,
72+
const int nx_read,
73+
const int ny_read,
74+
const int nz_read);
75+
7276
extern void write_cube_core(
7377
std::ofstream &ofs_cube,
7478
#ifdef __MPI
@@ -120,7 +124,7 @@ extern void write_cube_core(
120124
const int& ny,
121125
const int& nz,
122126
#ifdef __MPI
123-
double** data
127+
std::vector<std::vector<double>> &data
124128
#else
125129
double* data
126130
#endif

source/module_io/read_cube.cpp

Lines changed: 50 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -112,29 +112,61 @@ bool ModuleIO::read_cube(
112112
}
113113
}
114114

115+
const bool flag_read_rank = (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0));
115116
#ifdef __MPI
116-
ModuleIO::read_cube_core(ifs, Pgrid, my_rank, esolver_type, rank_in_stogroup, data, nx, ny, nz, nx_read, ny_read, nz_read);
117+
if(nx == nx_read && ny == ny_read && nz == nz_read)
118+
ModuleIO::read_cube_core_match(ifs, Pgrid, flag_read_rank, data, nx*ny, nz);
119+
else
120+
ModuleIO::read_cube_core_mismatch(ifs, Pgrid, flag_read_rank, data, nx, ny, nz, nx_read, ny_read, nz_read);
117121
#else
118-
ModuleIO::read_cube_core(ifs, my_rank, esolver_type, rank_in_stogroup, is, ofs_running, data, nx, ny, nz, nx_read, ny_read, nz_read);
122+
ofs_running << " Read SPIN = " << is + 1 << " charge now." << std::endl;
123+
if(nx == nx_read && ny == ny_read && nz == nz_read)
124+
ModuleIO::read_cube_core_match(ifs, data, nx*ny, nz);
125+
else
126+
ModuleIO::read_cube_core_mismatch(ifs, data, nx, ny, nz, nx_read, ny_read, nz_read);
119127
#endif
120128

121-
if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
122-
ifs.close();
123129
return true;
124130
}
125131

126-
void ModuleIO::read_cube_core(
132+
void ModuleIO::read_cube_core_match(
127133
std::ifstream &ifs,
128134
#ifdef __MPI
129135
const Parallel_Grid*const Pgrid,
136+
const bool flag_read_rank,
130137
#endif
131-
const int my_rank,
132-
const std::string esolver_type,
133-
const int rank_in_stogroup,
138+
double*const data,
139+
const int nxy,
140+
const int nz)
141+
{
134142
#ifdef __MPI
143+
if (flag_read_rank)
144+
{
145+
std::vector<std::vector<double>> read_rho(nz, std::vector<double>(nxy));
146+
for (int ixy = 0; ixy < nxy; ixy++)
147+
for (int iz = 0; iz < nz; iz++)
148+
ifs >> read_rho[iz][ixy];
149+
for (int iz = 0; iz < nz; iz++)
150+
Pgrid->zpiece_to_all(read_rho[iz].data(), iz, data);
151+
}
152+
else
153+
{
154+
std::vector<double> zpiece(nxy);
155+
for (int iz = 0; iz < nz; iz++)
156+
Pgrid->zpiece_to_all(zpiece.data(), iz, data);
157+
}
135158
#else
136-
const int is,
137-
std::ofstream& ofs_running,
159+
for (int ixy = 0; ixy < nxy; ixy++)
160+
for (int iz = 0; iz < nz; iz++)
161+
ifs >> data[iz * nxy + ixy];
162+
#endif
163+
}
164+
165+
void ModuleIO::read_cube_core_mismatch(
166+
std::ifstream &ifs,
167+
#ifdef __MPI
168+
const Parallel_Grid*const Pgrid,
169+
const bool flag_read_rank,
138170
#endif
139171
double*const data,
140172
const int nx,
@@ -144,83 +176,23 @@ void ModuleIO::read_cube_core(
144176
const int ny_read,
145177
const int nz_read)
146178
{
147-
const bool same = (nx == nx_read && ny == ny_read && nz == nz_read) ? true : false;
148-
149179
#ifdef __MPI
150180
const int nxy = nx * ny;
151-
double* zpiece = nullptr;
152-
double** read_rho = nullptr;
153-
if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
181+
if (flag_read_rank)
154182
{
155-
read_rho = new double*[nz];
183+
std::vector<std::vector<double>> read_rho(nz, std::vector<double>(nxy));
184+
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, read_rho);
156185
for (int iz = 0; iz < nz; iz++)
157-
{
158-
read_rho[iz] = new double[nxy];
159-
}
160-
if (same)
161-
{
162-
for (int ix = 0; ix < nx; ix++)
163-
{
164-
for (int iy = 0; iy < ny; iy++)
165-
{
166-
for (int iz = 0; iz < nz; iz++)
167-
{
168-
ifs >> read_rho[iz][ix * ny + iy];
169-
}
170-
}
171-
}
172-
}
173-
else
174-
{
175-
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, read_rho);
176-
}
186+
Pgrid->zpiece_to_all(read_rho[iz].data(), iz, data);
177187
}
178188
else
179189
{
180-
zpiece = new double[nxy];
181-
ModuleBase::GlobalFunc::ZEROS(zpiece, nxy);
182-
}
183-
184-
for (int iz = 0; iz < nz; iz++)
185-
{
186-
if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
187-
{
188-
zpiece = read_rho[iz];
189-
}
190-
Pgrid->zpiece_to_all(zpiece, iz, data);
191-
} // iz
192-
193-
if (my_rank == 0 || (esolver_type == "sdft" && rank_in_stogroup == 0))
194-
{
190+
std::vector<double> zpiece(nxy);
195191
for (int iz = 0; iz < nz; iz++)
196-
{
197-
delete[] read_rho[iz];
198-
}
199-
delete[] read_rho;
200-
}
201-
else
202-
{
203-
delete[] zpiece;
192+
Pgrid->zpiece_to_all(zpiece.data(), iz, data);
204193
}
205194
#else
206-
ofs_running << " Read SPIN = " << is + 1 << " charge now." << std::endl;
207-
if (same)
208-
{
209-
for (int i = 0; i < nx; i++)
210-
{
211-
for (int j = 0; j < ny; j++)
212-
{
213-
for (int k = 0; k < nz; k++)
214-
{
215-
ifs >> data[k * nx * ny + i * ny + j];
216-
}
217-
}
218-
}
219-
}
220-
else
221-
{
222-
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, data);
223-
}
195+
ModuleIO::trilinear_interpolate(ifs, nx_read, ny_read, nz_read, nx, ny, nz, data);
224196
#endif
225197
}
226198

@@ -232,7 +204,7 @@ void ModuleIO::trilinear_interpolate(std::ifstream& ifs,
232204
const int& ny,
233205
const int& nz,
234206
#ifdef __MPI
235-
double** data
207+
std::vector<std::vector<double>> &data
236208
#else
237209
double* data
238210
#endif

0 commit comments

Comments
 (0)