Skip to content

Commit 21f7818

Browse files
committed
fix compile
1 parent c5dc8ce commit 21f7818

File tree

14 files changed

+126
-109
lines changed

14 files changed

+126
-109
lines changed

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,13 @@ template struct line_minimize_with_block_op<std::complex<float>, base_device::DE
382382

383383
template struct scal_op<double, base_device::DEVICE_CPU>;
384384
template struct axpy_op<std::complex<double>, base_device::DEVICE_CPU>;
385+
template struct axpy_op<double, base_device::DEVICE_CPU>;
385386
template struct gemv_op<std::complex<double>, base_device::DEVICE_CPU>;
386387
template struct gemv_op<double, base_device::DEVICE_CPU>;
387388
template struct gemm_op<std::complex<double>, base_device::DEVICE_CPU>;
388389
template struct gemm_op<double, base_device::DEVICE_CPU>;
389390
template struct dot_real_op<std::complex<double>, base_device::DEVICE_CPU>;
391+
template struct dot_real_op<double, base_device::DEVICE_CPU>;
390392
template struct vector_div_constant_op<std::complex<double>, base_device::DEVICE_CPU>;
391393
template struct vector_mul_vector_op<std::complex<double>, base_device::DEVICE_CPU>;
392394
template struct vector_div_vector_op<std::complex<double>, base_device::DEVICE_CPU>;
@@ -397,8 +399,6 @@ template struct calc_grad_with_block_op<std::complex<double>, base_device::DEVIC
397399
template struct line_minimize_with_block_op<std::complex<double>, base_device::DEVICE_CPU>;
398400

399401
#ifdef __LCAO
400-
template struct axpy_op<double, base_device::DEVICE_CPU>;
401-
template struct dot_real_op<double, base_device::DEVICE_CPU>;
402402
template struct vector_mul_vector_op<double, base_device::DEVICE_CPU>;
403403
template struct vector_div_constant_op<double, base_device::DEVICE_CPU>;
404404
template struct vector_div_vector_op<double, base_device::DEVICE_CPU>;

source/module_base/test_parallel/parallel_global_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class MPIContext
5555
int rank_in_pool;
5656

5757
int nstogroup;
58-
int my_stogroup;
58+
int MY_BNDGROUP;
5959
int rank_in_stogroup;
6060
int nproc_in_stogroup;
6161

@@ -173,7 +173,7 @@ TEST_F(ParaGlobal, InitPools)
173173
mpi.kpar,
174174
mpi.nproc_in_stogroup,
175175
mpi.rank_in_stogroup,
176-
mpi.my_stogroup,
176+
mpi.MY_BNDGROUP,
177177
mpi.nproc_in_pool,
178178
mpi.rank_in_pool,
179179
mpi.my_pool), ::testing::ExitedWithCode(1), "");

source/module_cell/cal_atoms_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class CalAtomsInfo
7070
elecstate::cal_nbands(para.inp.nelec, para.sys.nlocal, nelec_spin, para.input.nbands);
7171
// calculate the number of nbands_local
7272
para.sys.nbands_l = para.inp.nbands;
73-
if (inp.ks_solver == "bpcg") // only bpcg support band parallel
73+
if (para.inp.ks_solver == "bpcg") // only bpcg support band parallel
7474
{
7575
para.sys.nbands_l = para.inp.nbands / para.inp.bndpar;
7676
if (GlobalV::RANK_IN_BPGROUP < para.inp.nbands % para.inp.bndpar)

source/module_cell/test/klist_test_para.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ TEST_F(KlistParaTest, Set)
231231
GlobalV::KPAR,
232232
GlobalV::NPROC_IN_BNDGROUP,
233233
GlobalV::RANK_IN_BPGROUP,
234-
GlobalV::MY_STOGROUP,
234+
GlobalV::MY_BNDGROUP,
235235
GlobalV::NPROC_IN_POOL,
236236
GlobalV::RANK_IN_POOL,
237237
GlobalV::MY_POOL);
@@ -288,7 +288,7 @@ TEST_F(KlistParaTest, SetAfterVC)
288288
GlobalV::KPAR,
289289
GlobalV::NPROC_IN_BNDGROUP,
290290
GlobalV::RANK_IN_BPGROUP,
291-
GlobalV::MY_STOGROUP,
291+
GlobalV::MY_BNDGROUP,
292292
GlobalV::NPROC_IN_POOL,
293293
GlobalV::RANK_IN_POOL,
294294
GlobalV::MY_POOL);

source/module_elecstate/test/elecstate_print_test.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class ElecStatePrintTest : public ::testing::Test
9898
ucell.magnet.tot_magnetization_nc[1] = 4.4;
9999
ucell.magnet.tot_magnetization_nc[2] = 5.5;
100100
PARAM.input.ks_solver = "dav";
101+
PARAM.sys.log_file = "test.dat";
101102
}
102103
void TearDown()
103104
{
@@ -120,11 +121,11 @@ TEST_F(ElecStatePrintTest, PrintFormat)
120121
TEST_F(ElecStatePrintTest, PrintEigenvalueS2)
121122
{
122123
PARAM.input.nspin = 2;
123-
GlobalV::ofs_running.open("running_scf.log", std::ios::out);
124+
GlobalV::ofs_running.open("test.dat", std::ios::out);
124125
// print eigenvalue
125126
elecstate.print_eigenvalue(GlobalV::ofs_running);
126127
GlobalV::ofs_running.close();
127-
ifs.open("running_scf.log", std::ios::in);
128+
ifs.open("test.dat", std::ios::in);
128129
std::string str((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
129130
EXPECT_THAT(str, testing::HasSubstr("STATE ENERGY(eV) AND OCCUPATIONS"));
130131
EXPECT_THAT(str, testing::HasSubstr("NSPIN == 2"));
@@ -137,17 +138,17 @@ TEST_F(ElecStatePrintTest, PrintEigenvalueS2)
137138
EXPECT_THAT(str, testing::HasSubstr("1 40.8171 0.300000"));
138139
EXPECT_THAT(str, testing::HasSubstr("2 54.4228 0.400000"));
139140
ifs.close();
140-
std::remove("running_scf.log");
141+
std::remove("test.dat");
141142
}
142143

143144
TEST_F(ElecStatePrintTest, PrintEigenvalueS4)
144145
{
145146
PARAM.input.nspin = 4;
146-
GlobalV::ofs_running.open("running_scf.log", std::ios::out);
147+
GlobalV::ofs_running.open("test.dat", std::ios::out);
147148
// print eigenvalue
148149
elecstate.print_eigenvalue(GlobalV::ofs_running);
149150
GlobalV::ofs_running.close();
150-
ifs.open("running_scf.log", std::ios::in);
151+
ifs.open("test.dat", std::ios::in);
151152
std::string str((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
152153
EXPECT_THAT(str, testing::HasSubstr("STATE ENERGY(eV) AND OCCUPATIONS"));
153154
EXPECT_THAT(str, testing::HasSubstr("NSPIN == 4"));
@@ -158,51 +159,51 @@ TEST_F(ElecStatePrintTest, PrintEigenvalueS4)
158159
EXPECT_THAT(str, testing::HasSubstr("1 40.8171 0.300000"));
159160
EXPECT_THAT(str, testing::HasSubstr("2 54.4228 0.400000"));
160161
ifs.close();
161-
std::remove("running_scf.log");
162+
std::remove("test.dat");
162163
}
163164

164165
TEST_F(ElecStatePrintTest, PrintBand)
165166
{
166167
PARAM.input.nspin = 1;
167168
PARAM.input.nbands = 2;
168169
GlobalV::MY_RANK = 0;
169-
GlobalV::ofs_running.open("running_scf.log", std::ios::out);
170+
GlobalV::ofs_running.open("test.dat", std::ios::out);
170171
// print eigenvalue
171172
elecstate.print_band(0, 1, 0);
172173
GlobalV::ofs_running.close();
173-
ifs.open("running_scf.log", std::ios::in);
174+
ifs.open("test.dat", std::ios::in);
174175
std::string str((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
175176
EXPECT_THAT(str, testing::HasSubstr("Energy (eV) & Occupations for spin=1 K-point=1"));
176177
EXPECT_THAT(str, testing::HasSubstr("1 13.6057 0.100000"));
177178
EXPECT_THAT(str, testing::HasSubstr("2 27.2114 0.200000"));
178179
ifs.close();
179-
std::remove("running_scf.log");
180+
std::remove("test.dat");
180181
}
181182

182183
TEST_F(ElecStatePrintTest, PrintEigenvalueWarning)
183184
{
184185
elecstate.ekb(0, 0) = 1.0e11;
185186
PARAM.input.nspin = 4;
186-
GlobalV::ofs_running.open("running_scf.log", std::ios::out);
187+
GlobalV::ofs_running.open("test.dat", std::ios::out);
187188
testing::internal::CaptureStdout();
188189
EXPECT_EXIT(elecstate.print_eigenvalue(GlobalV::ofs_running), ::testing::ExitedWithCode(1), "");
189190
output = testing::internal::GetCapturedStdout();
190191
EXPECT_THAT(output, testing::HasSubstr("Eigenvalues are too large!"));
191192
GlobalV::ofs_running.close();
192-
std::remove("running_scf.log");
193+
std::remove("test.dat");
193194
}
194195

195196
TEST_F(ElecStatePrintTest, PrintBandWarning)
196197
{
197198
elecstate.ekb(0, 0) = 1.0e11;
198199
PARAM.input.nspin = 4;
199-
GlobalV::ofs_running.open("running_scf.log", std::ios::out);
200+
GlobalV::ofs_running.open("test.dat", std::ios::out);
200201
testing::internal::CaptureStdout();
201202
EXPECT_EXIT(elecstate.print_band(0, 1, 0), ::testing::ExitedWithCode(1), "");
202203
output = testing::internal::GetCapturedStdout();
203204
EXPECT_THAT(output, testing::HasSubstr("Eigenvalues are too large!"));
204205
GlobalV::ofs_running.close();
205-
std::remove("running_scf.log");
206+
std::remove("test.dat");
206207
}
207208

208209
TEST_F(ElecStatePrintTest, PrintEtot)

source/module_elecstate/test_mpi/charge_mpi_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1)
7272
GlobalV::KPAR,
7373
GlobalV::NPROC_IN_BNDGROUP,
7474
GlobalV::RANK_IN_BPGROUP,
75-
GlobalV::MY_STOGROUP,
75+
GlobalV::MY_BNDGROUP,
7676
GlobalV::NPROC_IN_POOL,
7777
GlobalV::RANK_IN_POOL,
7878
GlobalV::MY_POOL);
@@ -118,7 +118,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2)
118118
GlobalV::KPAR,
119119
GlobalV::NPROC_IN_BNDGROUP,
120120
GlobalV::RANK_IN_BPGROUP,
121-
GlobalV::MY_STOGROUP,
121+
GlobalV::MY_BNDGROUP,
122122
GlobalV::NPROC_IN_POOL,
123123
GlobalV::RANK_IN_POOL,
124124
GlobalV::MY_POOL);
@@ -173,7 +173,7 @@ TEST_F(ChargeMpiTest, rho_mpi)
173173
GlobalV::KPAR,
174174
GlobalV::NPROC_IN_BNDGROUP,
175175
GlobalV::RANK_IN_BPGROUP,
176-
GlobalV::MY_STOGROUP,
176+
GlobalV::MY_BNDGROUP,
177177
GlobalV::NPROC_IN_POOL,
178178
GlobalV::RANK_IN_POOL,
179179
GlobalV::MY_POOL);

source/module_hsolver/test/test_hsolver_sdft.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class TestHSolverPW_SDFT : public ::testing::Test
252252
// stowf.nchip_max = 0;
253253
// psi_test_cd.resize(1, 2, 3);
254254
// PARAM.input.nelec = 1.0;
255-
// GlobalV::MY_STOGROUP = 0.0;
255+
// GlobalV::MY_BNDGROUP = 0.0;
256256
// int istep = 0;
257257
// int iter = 0;
258258

@@ -291,7 +291,7 @@ class TestHSolverPW_SDFT : public ::testing::Test
291291
// psi_test_no.nbands = 0;
292292
// psi_test_no.nbasis = 0;
293293
// PARAM.input.nelec = 1.0;
294-
// GlobalV::MY_STOGROUP = 0.0;
294+
// GlobalV::MY_BNDGROUP = 0.0;
295295
// PARAM.input.nspin = 1;
296296
// elecstate_test.charge = new Charge;
297297
// elecstate_test.charge->rho = new double*[1];

source/module_io/read_input.cpp

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "module_base/global_function.h"
1414
#include "module_base/tool_quit.h"
1515
#include "module_base/tool_title.h"
16+
#include "module_base/module_device/device.h"
1617
namespace ModuleIO
1718
{
1819

@@ -112,39 +113,54 @@ void ReadInput::read_parameters(Parameter& param, const std::string& filename_in
112113
// 1. only rank 0 read the input file
113114
if (this->rank == 0)
114115
{
115-
// 1. read the input file
116116
// We can also easily add other input file formats here
117117
this->read_txt_input(param, filename_in);
118-
119-
// 2. check the value of the parameters
120-
for (auto& input_item: this->input_lists)
121-
{
122-
Input_Item* checkvalue_item = &(input_item.second);
123-
if (checkvalue_item->check_value != nullptr)
124-
{
125-
checkvalue_item->check_value(*checkvalue_item, param);
126-
}
127-
}
128118
}
129119

130-
// 3. check the number of atom types from STRU file
120+
// 2. check the number of atom types from STRU file
131121
// set the global directories
132122
this->set_global_dir(param.inp, param.sys);
133123
if (this->check_ntype_flag && this->rank == 0)
134124
{
135125
check_ntype(param.globalv.global_in_stru, param.input.ntype);
136126
}
137127

138-
// 4. broadcast input parameters
128+
// 3. broadcast input parameters
139129
// It must be after the check_ntype, because some parameters need to be filled due to ntype
140130
for (auto& bcastfunc: this->bcastfuncs)
141131
{
142132
bcastfunc(param);
143133
}
144134

145-
// 5. set the globalv parameters, some parameters in different processes are different. e.g. rank
135+
// 4. set the globalv parameters, some parameters in different processes are different. e.g. rank, log_file
146136
this->set_globalv(param.inp, param.sys);
147137

138+
// 5. check the value of the parameters
139+
// It must be after the check_ntype, because some parameters need to be checked according to ntype
140+
// It must be after the set_globalv, because some parameters need to be checked according to param.sys
141+
if (this->rank == 0)
142+
{
143+
for (auto& input_item: this->input_lists)
144+
{
145+
Input_Item* checkvalue_item = &(input_item.second);
146+
if (checkvalue_item->check_value != nullptr)
147+
{
148+
checkvalue_item->check_value(*checkvalue_item, param);
149+
}
150+
}
151+
}
152+
153+
// 6. check and reset kpar.
154+
// It must be after bcastfunc, and kpar and bndpar are synchronized
155+
// It must be before wirte_txt_input, because kpar is used in write_txt_input
156+
if (param.inp.device == "gpu" && param.inp.basis_type == "pw")
157+
{
158+
param.input.kpar = base_device::information::get_device_kpar(param.inp.kpar, param.inp.bndpar);
159+
}
160+
161+
162+
163+
148164
if (this->check_mode)
149165
{
150166
std::cout << "----------------------------------------------------------" << std::endl;

source/module_io/read_input_item_system.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,8 @@ void ReadInput::item_system()
234234
"will be distributed among";
235235
read_sync_int(input.kpar);
236236
item.reset_value = [](const Input_Item& item, Parameter& para) {
237-
if (para.inp.device == "gpu" && para.inp.basis_type == "pw")
238-
{
239-
para.input.kpar = base_device::information::get_device_kpar(para.inp.kpar, para.inp.bndpar);
240-
}
241237
#ifdef __LCAO
242-
else if (para.inp.basis_type == "lcao")
238+
if (para.inp.basis_type == "lcao")
243239
{
244240
para.sys.kpar_lcao = para.inp.kpar;
245241
para.input.kpar = 1;

0 commit comments

Comments
 (0)