1- #include < iostream>
21#include " module_hsolver/diag_hs_para.h"
2+
3+ #include " module_base/scalapack_connector.h"
34#include " module_basis/module_ao/parallel_2d.h"
45#include " module_hsolver/diago_pxxxgvx.h"
5- #include " module_base/scalapack_connector.h"
66#include " module_hsolver/genelpa/elpa_solver.h"
77
8+ #include < iostream>
9+
810namespace hsolver
911{
1012
11- #ifdef __ELPA
12- void elpa_diag (MPI_Comm comm,
13- const int nband,
14- std::complex <double >* h_local,
15- std::complex <double >* s_local,
16- double * ekb,
17- std::complex <double >* wfc_2d,
18- Parallel_2D& para2d_local)
19- {
20- int DecomposedState = 0 ;
21- ELPA_Solver es (false ,
22- comm,
23- nband,
24- para2d_local.get_row_size (),
25- para2d_local.get_col_size (),
26- para2d_local.desc );
27- es.generalized_eigenvector (h_local, s_local, DecomposedState, ekb, wfc_2d);
28- es.exit ();
29- }
13+ #ifdef __ELPA
14+ void elpa_diag (MPI_Comm comm,
15+ const int nband,
16+ std::complex <double >* h_local,
17+ std::complex <double >* s_local,
18+ double * ekb,
19+ std::complex <double >* wfc_2d,
20+ Parallel_2D& para2d_local)
21+ {
22+ int DecomposedState = 0 ;
23+ ELPA_Solver es (false , comm, nband, para2d_local.get_row_size (), para2d_local.get_col_size (), para2d_local.desc );
24+ es.generalized_eigenvector (h_local, s_local, DecomposedState, ekb, wfc_2d);
25+ es.exit ();
26+ }
3027
31- void elpa_diag (MPI_Comm comm,
32- const int nband,
33- double * h_local,
34- double * s_local,
35- double * ekb,
36- double * wfc_2d,
37- Parallel_2D& para2d_local)
38- {
39- int DecomposedState = 0 ;
40- ELPA_Solver es (true ,
41- comm,
42- nband,
43- para2d_local.get_row_size (),
44- para2d_local.get_col_size (),
45- para2d_local.desc );
46- es.generalized_eigenvector (h_local, s_local, DecomposedState, ekb, wfc_2d);
47- es.exit ();
48- }
28+ void elpa_diag (MPI_Comm comm,
29+ const int nband,
30+ double * h_local,
31+ double * s_local,
32+ double * ekb,
33+ double * wfc_2d,
34+ Parallel_2D& para2d_local)
35+ {
36+ int DecomposedState = 0 ;
37+ ELPA_Solver es (true , comm, nband, para2d_local.get_row_size (), para2d_local.get_col_size (), para2d_local.desc );
38+ es.generalized_eigenvector (h_local, s_local, DecomposedState, ekb, wfc_2d);
39+ es.exit ();
40+ }
4941
50- void elpa_diag (MPI_Comm comm,
51- const int nband,
52- std::complex <float >* h_local,
53- std::complex <float >* s_local,
54- float * ekb,
55- std::complex <float >* wfc_2d,
56- Parallel_2D& para2d_local)
57- {
58- std::cout << " Error: ELPA do not support single precision. " << std::endl;
59- exit (1 );
60- }
42+ void elpa_diag (MPI_Comm comm,
43+ const int nband,
44+ std::complex <float >* h_local,
45+ std::complex <float >* s_local,
46+ float * ekb,
47+ std::complex <float >* wfc_2d,
48+ Parallel_2D& para2d_local)
49+ {
50+ std::cout << " Error: ELPA do not support single precision. " << std::endl;
51+ exit (1 );
52+ }
6153
62- void elpa_diag (MPI_Comm comm,
63- const int nband,
64- float * h_local,
65- float * s_local,
66- float * ekb,
67- float * wfc_2d,
68- Parallel_2D& para2d_local)
69- {
70- std::cout << " Error: ELPA do not support single precision. " << std::endl;
71- exit (1 );
72- }
54+ void elpa_diag (MPI_Comm comm,
55+ const int nband,
56+ float * h_local,
57+ float * s_local,
58+ float * ekb,
59+ float * wfc_2d,
60+ Parallel_2D& para2d_local)
61+ {
62+ std::cout << " Error: ELPA do not support single precision. " << std::endl;
63+ exit (1 );
64+ }
7365
7466#endif
7567
76-
7768#ifdef __MPI
7869
7970template <typename T>
80- void Diago_HS_para (
81- T* h,
82- T* s,
83- const int lda,
84- const int nband,
85- typename GetTypeReal<T>::type *const ekb,
86- T *const wfc,
87- const MPI_Comm& comm,
88- const int diag_subspace_method,
89- const int block_size)
71+ void Diago_HS_para (T* h,
72+ T* s,
73+ const int lda,
74+ const int nband,
75+ typename GetTypeReal<T>::type* const ekb,
76+ T* const wfc,
77+ const MPI_Comm& comm,
78+ const int diag_subspace,
79+ const int block_size)
9080{
91- int myrank;
81+ int myrank = 0 ;
9282 MPI_Comm_rank (comm, &myrank);
9383 Parallel_2D para2d_global;
9484 Parallel_2D para2d_local;
95- para2d_global.init (lda,lda,lda,comm);
85+ para2d_global.init (lda, lda, lda, comm);
9686
9787 int max_nb = block_size;
9888 if (block_size == 0 )
@@ -113,88 +103,99 @@ void Diago_HS_para(
113103 }
114104
115105 // for genelpa, if the block size is too large that some cores have no data, then it will cause error.
116- if (diag_subspace_method == 1 )
106+ if (diag_subspace == 1 )
117107 {
118108 if (max_nb * (std::max (para2d_global.dim0 , para2d_global.dim1 ) - 1 ) >= lda)
119109 {
120110 max_nb = lda / std::max (para2d_global.dim0 , para2d_global.dim1 );
121111 }
122112 }
123-
124- para2d_local.init (lda,lda,max_nb,comm);
113+
114+ para2d_local.init (lda, lda, max_nb, comm);
125115 std::vector<T> h_local (para2d_local.get_col_size () * para2d_local.get_row_size ());
126116 std::vector<T> s_local (para2d_local.get_col_size () * para2d_local.get_row_size ());
127117 std::vector<T> wfc_2d (para2d_local.get_col_size () * para2d_local.get_row_size ());
128-
118+
129119 // distribute h and s to 2D
130- Cpxgemr2d (lda,lda,h, 1 , 1 , para2d_global.desc ,h_local.data (),1 , 1 , para2d_local.desc ,para2d_local.blacs_ctxt );
131- Cpxgemr2d (lda,lda,s, 1 , 1 , para2d_global.desc ,s_local.data (),1 , 1 , para2d_local.desc ,para2d_local.blacs_ctxt );
120+ Cpxgemr2d (lda, lda, h, 1 , 1 , para2d_global.desc , h_local.data (), 1 , 1 , para2d_local.desc , para2d_local.blacs_ctxt );
121+ Cpxgemr2d (lda, lda, s, 1 , 1 , para2d_global.desc , s_local.data (), 1 , 1 , para2d_local.desc , para2d_local.blacs_ctxt );
132122
133- if (diag_subspace_method == 1 )
123+ if (diag_subspace == 1 )
134124 {
135- #ifdef __ELPA
125+ #ifdef __ELPA
136126 elpa_diag (comm, nband, h_local.data (), s_local.data (), ekb, wfc_2d.data (), para2d_local);
137127#else
138- std::cout << " ERROR: try to use ELPA to solve the generalized eigenvalue problem, but ELPA is not compiled. " << std::endl;
128+ std::cout << " ERROR: try to use ELPA to solve the generalized eigenvalue problem, but ELPA is not compiled. "
129+ << std::endl;
139130 exit (1 );
140- #endif
131+ #endif
141132 }
142- else if (diag_subspace_method == 2 )
133+ else if (diag_subspace == 2 )
143134 {
144- hsolver::pxxxgvx_diag (para2d_local.desc , para2d_local.get_row_size (), para2d_local.get_col_size (),nband, h_local.data (), s_local.data (), ekb, wfc_2d.data ());
135+ hsolver::pxxxgvx_diag (para2d_local.desc ,
136+ para2d_local.get_row_size (),
137+ para2d_local.get_col_size (),
138+ nband,
139+ h_local.data (),
140+ s_local.data (),
141+ ekb,
142+ wfc_2d.data ());
145143 }
146- else {
147- std::cout << " Error: parallel diagonalization method is not supported. " << " diag_subspace_method = " << diag_subspace_method << std::endl;
144+ else
145+ {
146+ std::cout << " Error: parallel diagonalization method is not supported. " << " diag_subspace = " << diag_subspace
147+ << std::endl;
148148 exit (1 );
149149 }
150150
151151 // gather wfc
152- Cpxgemr2d (lda,lda,wfc_2d.data (),1 , 1 , para2d_local.desc ,wfc,1 , 1 , para2d_global.desc ,para2d_local.blacs_ctxt );
152+ Cpxgemr2d (lda, lda, wfc_2d.data (), 1 , 1 , para2d_local.desc , wfc, 1 , 1 , para2d_global.desc , para2d_local.blacs_ctxt );
153153
154154 // free the context
155155 Cblacs_gridexit (para2d_local.blacs_ctxt );
156156 Cblacs_gridexit (para2d_global.blacs_ctxt );
157157}
158158
159159// template instantiation
160- template void Diago_HS_para<double >(double * h,
161- double * s,
162- const int lda,
163- const int nband,
164- typename GetTypeReal<double >::type *const ekb,
165- double *const wfc,
166- const MPI_Comm& comm,
167- const int diag_subspace_method,
168- const int block_size);
169- template void Diago_HS_para<std::complex <double >>(std::complex <double >* h,
170- std::complex <double >* s,
171- const int lda,
172- const int nband,
173- typename GetTypeReal<std::complex <double >>::type *const ekb,
174- std::complex <double > *const wfc,
175- const MPI_Comm& comm,
176- const int diag_subspace_method,
177- const int block_size);
178- template void Diago_HS_para<float >(float * h,
179- float * s,
180- const int lda,
181- const int nband,
182- typename GetTypeReal<float >::type *const ekb,
183- float *const wfc,
184- const MPI_Comm& comm,
185- const int diag_subspace_method,
160+ template void Diago_HS_para<double >(double * h,
161+ double * s,
162+ const int lda,
163+ const int nband,
164+ typename GetTypeReal<double >::type* const ekb,
165+ double * const wfc,
166+ const MPI_Comm& comm,
167+ const int diag_subspace,
186168 const int block_size);
187- template void Diago_HS_para<std::complex <float >>(std::complex <float >* h,
188- std::complex <float >* s,
189- const int lda,
190- const int nband,
191- typename GetTypeReal<std::complex <float >>::type *const ekb,
192- std::complex <float > *const wfc,
193- const MPI_Comm& comm,
194- const int diag_subspace_method,
195- const int block_size);
196-
197169
170+ template void Diago_HS_para<std::complex <double >>(std::complex <double >* h,
171+ std::complex <double >* s,
172+ const int lda,
173+ const int nband,
174+ typename GetTypeReal<std::complex <double >>::type* const ekb,
175+ std::complex <double >* const wfc,
176+ const MPI_Comm& comm,
177+ const int diag_subspace,
178+ const int block_size);
179+
180+ template void Diago_HS_para<float >(float * h,
181+ float * s,
182+ const int lda,
183+ const int nband,
184+ typename GetTypeReal<float >::type* const ekb,
185+ float * const wfc,
186+ const MPI_Comm& comm,
187+ const int diag_subspace,
188+ const int block_size);
189+
190+ template void Diago_HS_para<std::complex <float >>(std::complex <float >* h,
191+ std::complex <float >* s,
192+ const int lda,
193+ const int nband,
194+ typename GetTypeReal<std::complex <float >>::type* const ekb,
195+ std::complex <float >* const wfc,
196+ const MPI_Comm& comm,
197+ const int diag_subspace,
198+ const int block_size);
198199
199200#endif
200201
0 commit comments