Skip to content

Commit ada3a8d

Browse files
committed
fix bug in the mpi set
1 parent 88b3605 commit ada3a8d

File tree

7 files changed

+17
-18
lines changed

7 files changed

+17
-18
lines changed

source/module_basis/module_pw/pw_basis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ class PW_Basis
437437

438438
std::string device = "cpu"; ///< cpu or gpu
439439
std::string precision = "double"; ///< single, double, mixing
440-
bool mpi_flag_ = true; ///< ture,is use mpi or not
440+
bool mpi_flag_ = true; ///< ture,is use mpi or not
441441
bool double_data_ = true; ///< if has double data
442442
bool float_data_ = false; ///< if has float data
443443
};

source/module_basis/module_pw/test_gpu/pw_test.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,8 @@ int main(int argc, char **argv)
4444
precision_flag = "double";
4545
#endif
4646
device_flag = "cpu";
47-
#ifdef __MPI
48-
int nproc, myrank ,mypool;
49-
MPI_Init(&argc,&argv);
50-
MPI_Comm_size(MPI_COMM_WORLD,&nproc);
51-
MPI_Comm_rank(MPI_COMM_WORLD,&myrank);
52-
#else
5347
nproc_in_pool = kpar = 1;
5448
rank_in_pool = 0;
55-
#endif
5649
#ifdef _OPENMP
5750
// ref: https://www.fftw.org/fftw3_doc/Usage-of-Multi_002dthreaded-FFTW.html
5851
fftw_init_threads();
@@ -62,9 +55,6 @@ int main(int argc, char **argv)
6255
testing::AddGlobalTestEnvironment(new TestEnv);
6356
testing::InitGoogleTest(&argc, argv);
6457
result = RUN_ALL_TESTS();
65-
#ifdef __MPI
66-
MPI_Finalize();
67-
#endif
6858
#ifdef _OPENMP
6959
fftw_cleanup_threads();
7060
#endif

source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ class PWBasisKTEST: public testing::Test
3636
std::string precision_double = "double";
3737
std::string precision_single = "single";
3838
std::string device_flag = "cpu";
39+
ModulePW::PW_Basis_K basis_k;
40+
void SetUp()
41+
{
42+
// basis_k= ModuleBase::PW_Basis_k(device_flag,preci)
43+
basis_k.set_mpi(false);
44+
}
3945
};
4046

4147
TEST_F(PWBasisKTEST,Constructor)
4248
{
43-
ModulePW::PW_Basis_K basis_k1;
4449
ModulePW::PW_Basis_K basis_k2(device_flag, precision_double);
45-
EXPECT_EQ(basis_k1.classname,"PW_Basis_K");
4650
EXPECT_EQ(basis_k2.classname,"PW_Basis_K");
4751
EXPECT_EQ(basis_k2.device,"cpu");
4852
EXPECT_EQ(basis_k2.fft_bundle.device,"cpu");
@@ -54,7 +58,6 @@ TEST_F(PWBasisKTEST,Constructor)
5458

5559
TEST_F(PWBasisKTEST,Initgrids1)
5660
{
57-
ModulePW::PW_Basis_K basis_k;
5861
double lat0 = 1.8897261254578281;
5962
ModuleBase::Matrix3 latvec(10.0,0.0,0.0,
6063
0.0,10.0,0.0,
@@ -80,7 +83,6 @@ TEST_F(PWBasisKTEST,Initgrids1)
8083

8184
TEST_F(PWBasisKTEST,Initgrids2)
8285
{
83-
ModulePW::PW_Basis_K basis_k;
8486
double lat0 = 1.8897261254578281;
8587
ModuleBase::Matrix3 latvec(10.0,0.0,0.0,
8688
0.0,10.0,0.0,
@@ -105,7 +107,6 @@ TEST_F(PWBasisKTEST,Initgrids2)
105107

106108
TEST_F(PWBasisKTEST, Initparameters)
107109
{
108-
ModulePW::PW_Basis_K basis_k(device_flag, precision_single);
109110
double lat0 = 1.8897261254578281;
110111
ModuleBase::Matrix3 latvec(10.0,0.0,0.0,
111112
0.0,10.0,0.0,
@@ -150,7 +151,6 @@ TEST_F(PWBasisKTEST, Initparameters)
150151

151152
TEST_F(PWBasisKTEST, SetupTransform)
152153
{
153-
ModulePW::PW_Basis_K basis_k(device_flag, precision_double);
154154
double lat0 = 1.8897261254578281;
155155
ModuleBase::Matrix3 latvec(10.0,0.0,0.0,
156156
0.0,10.0,0.0,
@@ -170,7 +170,6 @@ TEST_F(PWBasisKTEST, SetupTransform)
170170

171171
TEST_F(PWBasisKTEST, CollectLocalPW)
172172
{
173-
ModulePW::PW_Basis_K basis_k(device_flag, precision_double);
174173
double lat0 = 1.8897261254578281;
175174
ModuleBase::Matrix3 latvec(10.0,0.0,0.0,
176175
0.0,10.0,0.0,

source/module_basis/module_pw/test_serial/pw_basis_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class PWBasisTEST: public testing::Test
4848
std::string device_flag = "cpu";
4949
ModulePW::PW_Basis pwb;
5050
ModulePW::PW_Basis pwb1;
51+
void SetUp()
52+
{
53+
pwb.set_mpi(false);
54+
}
5155
};
5256

5357
TEST_F(PWBasisTEST,Constructor)

source/module_elecstate/test/charge_mixing_test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ class ChargeMixingTest : public ::testing::Test
8282
ChargeMixingTest()
8383
{
8484
// Init pw_basis
85+
pw_basis.set_mpi(false);
8586
pw_basis.initgrids(4, ModuleBase::Matrix3(1, 0, 0, 0, 1, 0, 0, 0, 1), 20);
8687
pw_basis.initparameters(false, 20);
8788
pw_basis.setuptransform();
8889
pw_basis.collect_local_pw();
90+
pw_dbasis.set_mpi(false);
8991
pw_dbasis.initgrids(4, ModuleBase::Matrix3(1, 0, 0, 0, 1, 0, 0, 0, 1), 40);
9092
pw_dbasis.initparameters(false, 40);
9193
pw_dbasis.setuptransform(&pw_basis);

source/module_elecstate/test/charge_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class ChargeTest : public ::testing::Test
8080
ucell = utp.SetUcellInfo();
8181
charge = new Charge;
8282
rhopw = new ModulePW::PW_Basis;
83+
rhopw->set_mpi(false);
8384
rhopw->initgrids(ucell->lat0, ucell->latvec, elecstate::tmp_gridecut);
8485
rhopw->distribute_r();
8586
rhopw->initparameters(false, elecstate::tmp_gridecut);

source/module_elecstate/test/potential_new_test.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class PotentialNewTest : public ::testing::Test
117117
solvent = new surchem();
118118
etxc = new double;
119119
vtxc = new double;
120+
rhopw->set_mpi(false);
120121
elecstate::Set_GlobalV_Default();
121122
}
122123
virtual void TearDown()
@@ -578,6 +579,7 @@ TEST_F(PotentialNewTest, InterpolateVrsDoubleGrids)
578579
XC_Functional::func_type = 3;
579580
XC_Functional::ked_flag = true;
580581
// Init pw_basis
582+
rhodpw->set_mpi(false);
581583
rhopw->initgrids(4, ModuleBase::Matrix3(1, 0, 0, 0, 1, 0, 0, 0, 1), 4);
582584
rhopw->initparameters(false, 4);
583585
rhopw->setuptransform();
@@ -627,6 +629,7 @@ TEST_F(PotentialNewTest, InterpolateVrsWarningQuit)
627629
rhopw->collect_local_pw();
628630
rhodpw->gamma_only = false;
629631

632+
rhodpw->set_mpi(false);
630633
rhodpw->initgrids(4, ModuleBase::Matrix3(1, 0, 0, 0, 1, 0, 0, 0, 1), 6);
631634
rhodpw->initparameters(false, 6);
632635
static_cast<ModulePW::PW_Basis_Sup*>(rhodpw)->setuptransform(rhopw);

0 commit comments

Comments
 (0)