Commit 200a9cf

feature: parallel solve subspace diagonalization in dav_subspace

Author: root
Parent: 0ecfbc4

File tree

23 files changed: +1202 -48 lines


docs/advanced/input_files/input-main.md

Lines changed: 13 additions & 1 deletion

```diff
@@ -39,6 +39,7 @@
 - [pw\_diag\_thr](#pw_diag_thr)
 - [pw\_diag\_nmax](#pw_diag_nmax)
 - [pw\_diag\_ndim](#pw_diag_ndim)
+- [diag\_subspace\_method](#diag_subspace_method)
 - [erf\_ecut](#erf_ecut)
 - [fft\_mode](#fft_mode)
 - [erf\_height](#erf_height)
@@ -783,7 +784,18 @@ These variables are used to control the plane wave related parameters.
 
 - **Type**: Integer
 - **Description**: Only useful when `ks_solver = dav` or `ks_solver = dav_subspace`. It sets the dimension of the workspace (the number of wavefunction packets; at least 2 are needed) for the Davidson method. A larger value may reduce the number of iterations, but it uses more memory and more CPU time in the subspace diagonalization.
-- **Default**: 4
+- **Default**: 4
+
+### diag_subspace_method
+
+- **Type**: Integer
+- **Description**: The method used to diagonalize the subspace in the dav_subspace method. The available options are:
+  - 0: LAPACK
+  - 1: GenELPA
+  - 2: ScaLAPACK
+
+  LAPACK solves on a single core only, while GenELPA and ScaLAPACK solve in parallel. If the system is small (e.g., fewer than 100 bands), LAPACK is recommended. If the system is large and run with MPI parallelism, GenELPA or ScaLAPACK is recommended, and GenELPA usually performs better. For GenELPA and ScaLAPACK, the block size can be set via [nb2d](#nb2d).
+
+- **Default**: 0
 
 ### erf_ecut
```
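Beyond the diff: a minimal INPUT fragment selecting the new parallel subspace solver might look like the sketch below. The keywords follow the documentation above; the numeric values are placeholders, not recommendations.

```
INPUT_PARAMETERS
ks_solver              dav_subspace
pw_diag_ndim           4
diag_subspace_method   1    # 0: LAPACK, 1: GenELPA, 2: ScaLAPACK
nb2d                   0    # block size for GenELPA/ScaLAPACK; 0 = choose automatically
```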

python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp

Lines changed: 3 additions & 1 deletion

```diff
@@ -138,7 +138,9 @@ class PyDiagoDavSubspace
             tol,
             max_iter,
             need_subspace,
-            comm_info
+            comm_info,
+            PARAM.inp.diag_subspace_method,
+            PARAM.inp.nb2d
         );
 
         return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr.data(), scf_type);
```

source/module_base/blacs_connector.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -39,7 +39,7 @@ extern "C"
 	// Informational and Miscellaneous
 	void Cblacs_gridinfo(int icontxt, int* nprow, int* npcol, int* myprow, int* mypcol);
 	void Cblacs_gridinit(int* icontxt, char* layout, int nprow, int npcol);
-	void Cblacs_gridexit(int* icontxt);
+	void Cblacs_gridexit(int icontxt);
 	int Cblacs_pnum(int icontxt, int prow, int pcol);
 	void Cblacs_pcoord(int icontxt, int pnum, int* prow, int* pcol);
 	void Cblacs_exit(int icontxt);
```
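Not part of the commit, but for orientation: the one-line fix brings the prototype in line with the C BLACS convention of passing the grid context to Cblacs_gridexit by value. A minimal lifecycle sketch; Cblacs_get is declared locally here for self-containment (in ABACUS it would come from blacs_connector.h):

```cpp
#include <mpi.h>

// Standard C BLACS calls, declared here so the sketch stands alone.
extern "C"
{
    void Cblacs_get(int icontxt, int what, int* val);
    void Cblacs_gridinit(int* icontxt, char* layout, int nprow, int npcol);
    void Cblacs_gridinfo(int icontxt, int* nprow, int* npcol, int* myprow, int* mypcol);
    void Cblacs_gridexit(int icontxt); // context by value, as in the fixed prototype
}

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int nprocs = 0;
    MPI_Comm_size(MPI_COMM_WORLD, &nprocs);

    int ictxt = 0;
    Cblacs_get(0, 0, &ictxt);                   // obtain the default system context
    char layout[] = "Row";
    Cblacs_gridinit(&ictxt, layout, 1, nprocs); // 1 x nprocs process grid

    int nprow = 0, npcol = 0, myprow = 0, mypcol = 0;
    Cblacs_gridinfo(ictxt, &nprow, &npcol, &myprow, &mypcol);

    Cblacs_gridexit(ictxt);                     // release the grid; handle passed by value
    MPI_Finalize();
    return 0;
}
```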

source/module_base/scalapack_connector.h

Lines changed: 11 additions & 0 deletions

```diff
@@ -85,6 +85,17 @@ extern "C"
         const double* vl, const double* vu, const int* il, const int* iu,
         const double* abstol, int* m, int* nz, double* w, const double* orfac, std::complex<double>* Z, const int* iz, const int* jz, const int* descz,
         std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* iwork, int* liwork, int* ifail, int* iclustr, double* gap, int* info);
+    void pssygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
+        const int* n, float* A, const int* ia, const int* ja, const int* desca, float* B, const int* ib, const int* jb, const int* descb,
+        const float* vl, const float* vu, const int* il, const int* iu,
+        const float* abstol, int* m, int* nz, float* w, const float* orfac, float* Z, const int* iz, const int* jz, const int* descz,
+        float* work, int* lwork, int* iwork, int* liwork, int* ifail, int* iclustr, float* gap, int* info);
+    void pchegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
+        const int* n, std::complex<float>* A, const int* ia, const int* ja, const int* desca, std::complex<float>* B, const int* ib, const int* jb, const int* descb,
+        const float* vl, const float* vu, const int* il, const int* iu,
+        const float* abstol, int* m, int* nz, float* w, const float* orfac, std::complex<float>* Z, const int* iz, const int* jz, const int* descz,
+        std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* iwork, int* liwork, int* ifail, int* iclustr, float* gap, int* info);
 
     void pzgetri_(
         const int *n,
```
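Not from the commit itself: the new prototypes follow the standard ScaLAPACK p?{sy,he}gvx convention, in which workspace sizes are obtained by a query call with lwork = liwork = -1 before the actual solve. A hedged sketch for the single-precision real driver; solve_sygvx_float is a hypothetical helper, and the distributed matrices A, B, Z with valid descriptors desca/descb/descz are assumed to be set up elsewhere:

```cpp
#include <vector>
#include "module_base/scalapack_connector.h"

// Two-phase p?sygvx call: workspace query, then solve.
// Eigenvalues land in w; eigenvectors in the distributed matrix Z.
void solve_sygvx_float(int n, int nband, float* A, const int* desca,
                       float* B, const int* descb, float* Z, const int* descz,
                       float* w /* eigenvalues, length n */)
{
    const int itype = 1, ione = 1;          // solve A x = lambda B x
    const char jobz = 'V', range = 'I', uplo = 'U';
    const float vl = 0.f, vu = 0.f;
    const float abstol = 0.f;               // <= 0: use the library default tolerance
    const float orfac = -1.f;               // < 0: use the default reorthogonalization factor
    const int il = 1, iu = nband;           // first nband eigenpairs
    int m = 0, nz = 0, info = 0;
    std::vector<int> ifail(n), iclustr(2 * 64); // 2*nprow*npcol in general; 64 procs assumed here
    std::vector<float> gap(64);                 // nprow*npcol in general

    // Phase 1: workspace query (lwork = liwork = -1).
    float wkopt = 0.f;
    int iwkopt = 0, lwork = -1, liwork = -1;
    pssygvx_(&itype, &jobz, &range, &uplo, &n, A, &ione, &ione, desca,
             B, &ione, &ione, descb, &vl, &vu, &il, &iu, &abstol, &m, &nz,
             w, &orfac, Z, &ione, &ione, descz, &wkopt, &lwork,
             &iwkopt, &liwork, ifail.data(), iclustr.data(), gap.data(), &info);

    // Phase 2: allocate the reported workspace and solve.
    lwork = static_cast<int>(wkopt);
    liwork = iwkopt;
    std::vector<float> work(lwork);
    std::vector<int> iwork(liwork);
    pssygvx_(&itype, &jobz, &range, &uplo, &n, A, &ione, &ione, desca,
             B, &ione, &ione, descb, &vl, &vu, &il, &iu, &abstol, &m, &nz,
             w, &orfac, Z, &ione, &ione, descz, work.data(), &lwork,
             iwork.data(), &liwork, ifail.data(), iclustr.data(), gap.data(), &info);
}
```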

source/module_hsolver/CMakeLists.txt

Lines changed: 3 additions & 0 deletions

```diff
@@ -9,6 +9,9 @@ list(APPEND objects
     hsolver_pw_sdft.cpp
     diago_iter_assist.cpp
     hsolver.cpp
+    diago_pxxxgvx.cpp
+    diag_hs_para.cpp
+
 )
 
 if(ENABLE_LCAO)
```
source/module_hsolver/diag_hs_para.cpp

Lines changed: 201 additions & 0 deletions

```cpp
#include "module_hsolver/diag_hs_para.h"

#include "module_base/scalapack_connector.h"
#include "module_basis/module_ao/parallel_2d.h"
#include "module_hsolver/diago_pxxxgvx.h"
#ifdef __ELPA
#include "module_hsolver/genelpa/elpa_solver.h"
#endif

#include <algorithm> // std::max
#include <complex>
#include <cstdlib>   // exit
#include <iostream>
#include <vector>

namespace hsolver
{

#ifdef __ELPA
void elpa_diag(MPI_Comm comm,
               const int nband,
               std::complex<double>* h_local,
               std::complex<double>* s_local,
               double* ekb,
               std::complex<double>* wfc_2d,
               Parallel_2D& para2d_local)
{
    int DecomposedState = 0;
    ELPA_Solver es(false, comm, nband, para2d_local.get_row_size(), para2d_local.get_col_size(), para2d_local.desc);
    es.generalized_eigenvector(h_local, s_local, DecomposedState, ekb, wfc_2d);
    es.exit();
}

void elpa_diag(MPI_Comm comm,
               const int nband,
               double* h_local,
               double* s_local,
               double* ekb,
               double* wfc_2d,
               Parallel_2D& para2d_local)
{
    int DecomposedState = 0;
    ELPA_Solver es(true, comm, nband, para2d_local.get_row_size(), para2d_local.get_col_size(), para2d_local.desc);
    es.generalized_eigenvector(h_local, s_local, DecomposedState, ekb, wfc_2d);
    es.exit();
}

void elpa_diag(MPI_Comm comm,
               const int nband,
               std::complex<float>* h_local,
               std::complex<float>* s_local,
               float* ekb,
               std::complex<float>* wfc_2d,
               Parallel_2D& para2d_local)
{
    std::cout << "Error: ELPA does not support single precision." << std::endl;
    exit(1);
}

void elpa_diag(MPI_Comm comm,
               const int nband,
               float* h_local,
               float* s_local,
               float* ekb,
               float* wfc_2d,
               Parallel_2D& para2d_local)
{
    std::cout << "Error: ELPA does not support single precision." << std::endl;
    exit(1);
}
#endif

#ifdef __MPI

template <typename T>
void Diago_HS_para(T* h,
                   T* s,
                   const int lda,
                   const int nband,
                   typename GetTypeReal<T>::type* const ekb,
                   T* const wfc,
                   const MPI_Comm& comm,
                   const int diag_subspace_method,
                   const int block_size)
{
    int myrank = 0;
    MPI_Comm_rank(comm, &myrank);
    Parallel_2D para2d_global;
    Parallel_2D para2d_local;
    // the "global" layout keeps the whole matrix on rank 0 (block size = lda)
    para2d_global.init(lda, lda, lda, comm);

    int max_nb = block_size;
    if (block_size == 0)
    {
        max_nb = (nband > 500) ? 32 : 16;
    }
    else if (block_size < 0)
    {
        std::cout << "Error: block_size in the subspace diagonalization should be 0 or a positive integer." << std::endl;
        exit(1);
    }

    // For GenELPA, a block size so large that some cores hold no data causes an error; cap it.
    if (diag_subspace_method == 1)
    {
        if (max_nb * (std::max(para2d_global.dim0, para2d_global.dim1) - 1) >= lda)
        {
            max_nb = lda / std::max(para2d_global.dim0, para2d_global.dim1);
        }
    }

    para2d_local.init(lda, lda, max_nb, comm);
    std::vector<T> h_local(para2d_local.get_col_size() * para2d_local.get_row_size());
    std::vector<T> s_local(para2d_local.get_col_size() * para2d_local.get_row_size());
    std::vector<T> wfc_2d(para2d_local.get_col_size() * para2d_local.get_row_size());

    // scatter H and S from rank 0 into the 2D block-cyclic layout
    Cpxgemr2d(lda, lda, h, 1, 1, para2d_global.desc, h_local.data(), 1, 1, para2d_local.desc, para2d_local.blacs_ctxt);
    Cpxgemr2d(lda, lda, s, 1, 1, para2d_global.desc, s_local.data(), 1, 1, para2d_local.desc, para2d_local.blacs_ctxt);

    if (diag_subspace_method == 1)
    {
#ifdef __ELPA
        elpa_diag(comm, nband, h_local.data(), s_local.data(), ekb, wfc_2d.data(), para2d_local);
#else
        std::cout << "ERROR: trying to use ELPA to solve the generalized eigenvalue problem, but ELPA is not compiled." << std::endl;
        exit(1);
#endif
    }
    else if (diag_subspace_method == 2)
    {
        hsolver::pxxxgvx_diag(para2d_local.desc,
                              para2d_local.get_row_size(),
                              para2d_local.get_col_size(),
                              nband,
                              h_local.data(),
                              s_local.data(),
                              ekb,
                              wfc_2d.data());
    }
    else
    {
        std::cout << "Error: parallel diagonalization method is not supported. "
                  << "diag_subspace_method = " << diag_subspace_method << std::endl;
        exit(1);
    }

    // gather the eigenvectors back to rank 0
    Cpxgemr2d(lda, lda, wfc_2d.data(), 1, 1, para2d_local.desc, wfc, 1, 1, para2d_global.desc, para2d_local.blacs_ctxt);

    // free the BLACS contexts
    Cblacs_gridexit(para2d_local.blacs_ctxt);
    Cblacs_gridexit(para2d_global.blacs_ctxt);
}

// template instantiations
template void Diago_HS_para<double>(double* h, double* s, const int lda, const int nband,
    typename GetTypeReal<double>::type* const ekb, double* const wfc, const MPI_Comm& comm,
    const int diag_subspace_method, const int block_size);
template void Diago_HS_para<std::complex<double>>(std::complex<double>* h, std::complex<double>* s,
    const int lda, const int nband, typename GetTypeReal<std::complex<double>>::type* const ekb,
    std::complex<double>* const wfc, const MPI_Comm& comm, const int diag_subspace_method, const int block_size);
template void Diago_HS_para<float>(float* h, float* s, const int lda, const int nband,
    typename GetTypeReal<float>::type* const ekb, float* const wfc, const MPI_Comm& comm,
    const int diag_subspace_method, const int block_size);
template void Diago_HS_para<std::complex<float>>(std::complex<float>* h, std::complex<float>* s,
    const int lda, const int nband, typename GetTypeReal<std::complex<float>>::type* const ekb,
    std::complex<float>* const wfc, const MPI_Comm& comm, const int diag_subspace_method, const int block_size);

#endif

} // namespace hsolver
```
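Not part of the commit: a quick worked example of the GenELPA block-size cap above, with illustrative numbers only (a 200 x 200 subspace on a 4 x 4 process grid).

```cpp
#include <algorithm>
#include <iostream>

// Mirrors the cap in Diago_HS_para: if nb blocks times (grid dimension - 1)
// already cover the whole matrix, some processes would own no data, so nb is reduced.
int cap_block_size(int nb, int dim0, int dim1, int lda)
{
    const int maxdim = std::max(dim0, dim1);
    if (nb * (maxdim - 1) >= lda)
    {
        nb = lda / maxdim; // every process row/column keeps at least one block
    }
    return nb;
}

int main()
{
    std::cout << cap_block_size(64, 4, 4, 200) << '\n'; // 64*3 = 192 < 200  -> stays 64
    std::cout << cap_block_size(80, 4, 4, 200) << '\n'; // 80*3 = 240 >= 200 -> capped to 50
    return 0;
}
```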
source/module_hsolver/diag_hs_para.h

Lines changed: 47 additions & 0 deletions

```cpp
#ifndef DIAG_HS_PARA_H
#define DIAG_HS_PARA_H

#include "module_base/macros.h"
#include "module_basis/module_ao/parallel_2d.h"

#ifdef __MPI
#include <mpi.h>
#endif

namespace hsolver
{

#ifdef __MPI

/**
 * @brief Solve the generalized eigenvalue problem H c = e S c in parallel.
 *
 * @tparam T double, std::complex<double>, float, or std::complex<float>
 * @param h the Hermitian matrix H
 * @param s the overlap matrix S
 * @param lda the leading dimension of H and S
 * @param nband the number of bands to be calculated
 * @param ekb to store the eigenvalues
 * @param wfc to store the eigenvectors
 * @param comm the communicator
 * @param diag_subspace_method the method used to solve the generalized eigenvalue problem
 * @param block_size the block size of the 2D block-cyclic distribution when using ELPA or ScaLAPACK
 *
 * @note 1. h and s must be the full matrices on rank 0 of the communicator; they are not referenced on the other ranks.
 * @note 2. wfc is complete on rank 0 and is not stored on the other ranks.
 * @note 3. diag_subspace_method must be 1 (GenELPA) or 2 (ScaLAPACK).
 * @note 4. block_size must be 0 or a positive integer; if 0, a value as large as the distribution allows is chosen.
 */
template <typename T>
void Diago_HS_para(T* h,
                   T* s,
                   const int lda,
                   const int nband,
                   typename GetTypeReal<T>::type* const ekb,
                   T* const wfc,
                   const MPI_Comm& comm,
                   const int diag_subspace_method,
                   const int block_size = 0);

#endif

} // namespace hsolver

#endif // DIAG_HS_PARA_H
```
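Not part of the commit: a minimal usage sketch of the interface declared above, assuming an MPI/ScaLAPACK build of ABACUS (-D__MPI); lda and nband are placeholders. Per the notes, H and S are supplied in full on rank 0, and ekb/wfc come back complete on rank 0.

```cpp
#include <mpi.h>
#include <complex>
#include <vector>
#include "module_hsolver/diag_hs_para.h"

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    const int lda = 128;  // subspace dimension (placeholder)
    const int nband = 16; // number of requested eigenpairs (placeholder)

    std::vector<std::complex<double>> h(lda * lda), s(lda * lda), wfc(lda * lda);
    std::vector<double> ekb(lda);
    if (rank == 0)
    {
        // Fill h with the subspace Hamiltonian and s with the overlap matrix;
        // in ABACUS these come from the Davidson workspace.
    }

    // diag_subspace_method = 2 (ScaLAPACK); block_size = 0 picks the default.
    hsolver::Diago_HS_para(h.data(), s.data(), lda, nband,
                           ekb.data(), wfc.data(), MPI_COMM_WORLD,
                           /*diag_subspace_method=*/2, /*block_size=*/0);

    MPI_Finalize();
    return 0;
}
```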
