88#include " ../diago_iter_assist.h"
99#include " diago_mock.h"
1010#include " mpi.h"
11+ #include " module_pw/unittest/test_tool.h"
1112#include < complex>
1213
1314#include " gtest/gtest.h"
@@ -57,21 +58,26 @@ class DiagoCGPrepare
5758 : nband(nband), npw(npw), sparsity(sparsity), reorder(reorder), eps(eps), maxiter(maxiter),
5859 threshold (threshold)
5960 {
61+ #ifdef __MPI
62+ MPI_Comm_size (MPI_COMM_WORLD, &nprocs);
63+ MPI_Comm_rank (MPI_COMM_WORLD, &mypnum);
64+ #endif
6065 }
6166
6267 int nband, npw, sparsity, maxiter, notconv;
6368 // eps is the convergence threshold within cg_diago
6469 double eps, avg_iter;
6570 bool reorder;
6671 double threshold;
72+ int nprocs=1 , mypnum=0 ;
6773 // threshold is the comparison standard between cg and lapack
6874
6975 void CompareEigen (double *precondition)
7076 {
7177 // calculate eigenvalues by LAPACK;
7278 double *e_lapack = new double [npw];
7379 ModuleBase::ComplexMatrix ev = DIAGOTEST::hmatrix;
74- lapackEigen (npw, ev, e_lapack, false );
80+ if (mypnum == 0 ) lapackEigen (npw, ev, e_lapack, false );
7581 // initial guess of psi by perturbing lapack psi
7682 ModuleBase::ComplexMatrix psiguess (nband, npw);
7783 std::default_random_engine p (1 );
@@ -80,7 +86,7 @@ class DiagoCGPrepare
8086 {
8187 for (int j = 0 ; j < npw; j++)
8288 {
83- double rand = static_cast <double >(u (p))/10 .;
89+ double rand = static_cast <double >(u (p))/10 .;
8490 // psiguess(i,j) = ev(j,i)*(1+rand);
8591 psiguess (i, j) = ev (j, i) * rand;
8692 }
@@ -89,24 +95,47 @@ class DiagoCGPrepare
8995 // ======================================================================
9096 double *en = new double [npw];
9197 int ik = 1 ;
92- hamilt::Hamilt* ha;
93- ha =new hamilt::HamiltPW;
94- Hamilt_PW* hpw;
95- int * ngk = new int [1 ];
96- // psi::Psi<std::complex<double>> psi(ngk,ik,nband,npw);
97- psi::Psi<std::complex <double >> psi;
98- psi.resize (ik,nband,npw);
99- // psi.fix_k(0);
98+ hamilt::Hamilt* ha;
99+ ha =new hamilt::HamiltPW;
100+ Hamilt_PW* hpw;
101+ int * ngk = new int [1 ];
102+ // psi::Psi<std::complex<double>> psi(ngk,ik,nband,npw);
103+ psi::Psi<std::complex <double >> psi;
104+ psi.resize (ik,nband,npw);
105+ // psi.fix_k(0);
100106 for (int i = 0 ; i < nband; i++)
101107 {
102108 for (int j = 0 ; j < npw; j++)
103109 {
104- psi (i,j)=psiguess (i,j);
105- }
106- }
107- hsolver::DiagoCG cg (hpw,precondition);
108- cg.diag (ha,psi,en);
109- // ======================================================================
110+ psi (i,j)=psiguess (i,j);
111+ }
112+ }
113+
114+ psi::Psi<std::complex <double >> psi_local;
115+ double * precondition_local;
116+ DIAGOTEST::npw_local = new int [nprocs];
117+ #ifdef __MPI
118+ DIAGOTEST::cal_division (DIAGOTEST::npw);
119+ DIAGOTEST::divide_hpsi (psi,psi_local); // will distribute psi and Hmatrix to each process
120+ precondition_local = new double [DIAGOTEST::npw_local[mypnum]];
121+ DIAGOTEST::divide_psi<double >(precondition,precondition_local);
122+ #else
123+ DIAGOTEST::hmatrix_local = DIAGOTEST::hmatrix;
124+ DIAGOTEST::npw_local[0 ] = DIAGOTEST::npw;
125+ psi_local = psi;
126+ precondition_local = new double [DIAGOTEST::npw];
127+ for (int i=0 ;i<DIAGOTEST::npw;i++) precondition_local[i] = precondition[i];
128+ #endif
129+ hsolver::DiagoCG cg (hpw,precondition_local);
130+ psi_local.fix_k (0 );
131+ double start, end;
132+ start = MPI_Wtime ();
133+ cg.diag (ha,psi_local,en);
134+ end = MPI_Wtime ();
135+ // if(mypnum == 0) printf("diago time:%7.3f\n",end-start);
136+ delete [] DIAGOTEST::npw_local;
137+ delete [] precondition_local;
138+ // ======================================================================
110139 for (int i = 0 ; i < nband; i++)
111140 {
112141 EXPECT_NEAR (en[i], e_lapack[i], threshold);
@@ -142,11 +171,11 @@ INSTANTIATE_TEST_SUITE_P(VerifyCG,
142171 DiagoCGTest,
143172 ::testing::Values (
144173 // nband, npw, sparsity, reorder, eps, maxiter, threshold
145- DiagoCGPrepare (10 , 500 , 0 , true , 1e-5 , 100 , 1e-3 ),
174+ DiagoCGPrepare (10 , 500 , 0 , true , 1e-5 , 300 , 1e-3 ),
146175 DiagoCGPrepare(20 , 500 , 6 , true , 1e-5 , 300 , 1e-3 ),
147176 DiagoCGPrepare(20 , 1000 , 8 , true , 1e-5 , 300 , 1e-3 ),
148177 DiagoCGPrepare(40 , 1000 , 8 , true , 1e-6 , 300 , 1e-3 )));
149- // DiagoCGPrepare(40, 2000, 8, true, 1e-5, 500, 1e-2)));
178+ // DiagoCGPrepare(40, 2000, 8, true, 1e-5, 500, 1e-2)));
150179 // the last one is passed but time-consumming.
151180
152181// check that the mock class HPsi work well
@@ -187,6 +216,8 @@ TEST(DiagoCGTest, ZHEEV)
187216}
188217
189218// cg for a 2x2 matrix
219+ #ifdef __MPI
220+ #else
190221TEST (DiagoCGTest, TwoByTwo)
191222{
192223 int dim = 2 ;
@@ -206,19 +237,20 @@ TEST(DiagoCGTest, TwoByTwo)
206237 DIAGOTEST::npw = dim;
207238 dcp.CompareEigen (hpsi.precond ());
208239}
240+ #endif
209241
210242TEST (DiagoCGTest, readH)
211243{
212244 // read Hamilt matrix from file data-H
213245 ModuleBase::ComplexMatrix hm;
214246 std::ifstream ifs;
215- ifs.open (" data-H " );
247+ ifs.open (" H-KPoints-Si64.dat " );
216248 DIAGOTEST::readh (ifs, hm);
217249 ifs.close ();
218250 int dim = hm.nr ;
219251 int nband = 10 ; // not nband < dim, here dim = 26 in data-H
220252 // nband, npw, sub, sparsity, reorder, eps, maxiter, threshold
221- DiagoCGPrepare dcp (nband, dim, 0 , true , 1e-4 , 300 , 1e-3 );
253+ DiagoCGPrepare dcp (nband, dim, 0 , true , 1e-5 , 500 , 1e-3 );
222254 hsolver::DiagoIterAssist::PW_DIAG_NMAX = dcp.maxiter ;
223255 hsolver::DiagoIterAssist::PW_DIAG_THR = dcp.eps ;
224256 HPsi hpsi;
@@ -230,13 +262,28 @@ TEST(DiagoCGTest, readH)
230262
231263int main (int argc, char **argv)
232264{
265+ int nproc = 1 , myrank = 0 ;
233266
234- MPI_Init (&argc, &argv);
267+ #ifdef __MPI
268+ int nproc_in_pool, kpar=1 , mypool, rank_in_pool;
269+ setupmpi (argc,argv,nproc, myrank);
270+ divide_pools (nproc, myrank, nproc_in_pool, kpar, mypool, rank_in_pool);
271+ GlobalV::NPROC_IN_POOL = nproc;
272+ #else
273+ MPI_Init (&argc, &argv);
274+ #endif
235275
236276 testing::InitGoogleTest (&argc, argv);
277+ ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance ()->listeners ();
278+ if (myrank != 0 ) delete listeners.Release (listeners.default_result_printer ());
279+
237280 int result = RUN_ALL_TESTS ();
281+ if (myrank == 0 && result != 0 )
282+ {
283+ std::cout << " ERROR:some tests are not passed" << std::endl;
284+ return result;
285+ }
238286
239287 MPI_Finalize ();
240-
241- return result;
288+ return 0 ;
242289}
0 commit comments