11#include " ../pw_basis.h"
22#include " ../../src_parallel/parallel_global.h"
3+ #ifdef __MPI
34#include " test_tool.h"
5+ #include " mpi.h"
6+ #endif
47#include " ../../module_base/timer.h"
58#include " ../../module_base/global_function.h"
69
@@ -13,8 +16,13 @@ int main(int argc,char **argv)
1316 int nproc, myrank;
1417 int nproc_in_pool, npool, mypool, rank_in_pool;
1518 npool = 1 ;
19+ #ifdef __MPI
1620 setupmpi (argc,argv,nproc, myrank);
1721 divide_pools (nproc, myrank, nproc_in_pool, npool, mypool, rank_in_pool);
22+ #else
23+ nproc = nproc_in_pool = npool = 1 ;
24+ myrank = mypool = rank_in_pool = 0 ;
25+ #endif
1826
1927 ModuleBase::timer::start ();
2028
@@ -25,14 +33,19 @@ int main(int argc,char **argv)
2533 pwtest.initparameters (gamma_only, ecut, nproc, rank_in_pool, distribution_type);
2634 pwtest.distribute_r ();
2735 pwtest.distribute_g ();
28- MPI_Barrier (POOL_WORLD);
2936
3037 int tot_npw = 0 ;
3138 int nxy = pwtest.nx * pwtest.ny ;
39+ #ifdef __MPI
3240 MPI_Reduce (&pwtest.npw , &tot_npw, 1 , MPI_INT, MPI_SUM, 0 , POOL_WORLD);
41+ #else
42+ tot_npw = pwtest.npw ;
43+ #endif
3344 for (int ip = 0 ; ip < nproc; ip++)
3445 {
46+ #ifdef __MPI
3547 MPI_Barrier (POOL_WORLD);
48+ #endif
3649 if (rank_in_pool == ip)
3750 {
3851 std::cout<<" ip: " <<ip<<' \n ' ;
@@ -58,11 +71,15 @@ int main(int argc,char **argv)
5871 std::cout << " \n " ;
5972 }
6073 }
74+ #ifdef __MPI
6175 MPI_Barrier (POOL_WORLD);
76+ #endif
6277 pwtest.collect_local_pw ();
6378 for (int ip = 0 ; ip < nproc; ip++)
6479 {
80+ #ifdef __MPI
6581 MPI_Barrier (POOL_WORLD);
82+ #endif
6683 if (rank_in_pool == ip)
6784 {
6885 std::cout<<" ip: " <<ip<<' \n ' ;
@@ -80,23 +97,5 @@ int main(int argc,char **argv)
8097 }
8198 if (rank_in_pool==0 ) ModuleBase::timer::finish (GlobalV::ofs_running, true );
8299
83- // if (rank_in_pool == 0)
84- // {
85- // std::cout << "tot_npw " << tot_npw << "\n";
86- // double* gg_global = new double[tot_npw];
87- // ModuleBase::Vector3<double> *gdirect_global = new ModuleBase::Vector3<double>[tot_npw];
88- // ModuleBase::Vector3<double> *gcar_global = new ModuleBase::Vector3<double>[tot_npw];
89- // pwtest.collect_tot_pw(gg_global, gdirect_global, gcar_global);
90- // std::cout<<"gg_global gdirect_global gcar_global\n";
91- // for (int ig = 0; ig < tot_npw; ++ig)
92- // {
93- // std::cout << gg_global[ig] << std::setw(4) << gdirect_global[ig] << std::setw(4) << gcar_global[ig];
94- // std::cout << "\n";
95- // }
96- // std::cout<<"done"<<"\n";
97- // delete[] gg_global;
98- // delete[] gdirect_global;
99- // delete[] gcar_global;
100- // }
101100 return 0 ;
102101}
0 commit comments