Skip to content

Commit 2cf60e2

Browse files
author
root
committed
fix ut
1 parent 4127219 commit 2cf60e2

File tree

16 files changed

+709
-269
lines changed

16 files changed

+709
-269
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
- [pw\_diag\_thr](#pw_diag_thr)
4040
- [pw\_diag\_nmax](#pw_diag_nmax)
4141
- [pw\_diag\_ndim](#pw_diag_ndim)
42-
- [diag\_subspace\_method](#diag_subspace_method)
42+
- [diag\_subspace\_method](#diag_subspace)
4343
- [erf\_ecut](#erf_ecut)
4444
- [fft\_mode](#fft_mode)
4545
- [erf\_height](#erf_height)
@@ -787,7 +787,7 @@ These variables are used to control the plane wave related parameters.
787787
- **Description**: Only useful when you use `ks_solver = dav` or `ks_solver = dav_subspace`. It indicates dimension of workspace(number of wavefunction packets, at least 2 needed) for the Davidson method. A larger value may yield a smaller number of iterations in the algorithm but uses more memory and more CPU time in subspace diagonalization.
788788
- **Default**: 4
789789

790-
### diag_subspace_method
790+
### diag_subspace
791791

792792
- **Type**: Integer
793793
- **Description**: The method to diagonalize subspace in dav_subspace method. The available options are:

python/pyabacus/src/hsolver/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ list(APPEND _diago
55
${HSOLVER_PATH}/diago_cg.cpp
66
${HSOLVER_PATH}/diag_const_nums.cpp
77
${HSOLVER_PATH}/diago_iter_assist.cpp
8+
${HSOLVER_PATH}/diag_hs_para.cpp
9+
${HSOLVER_PATH}/diago_pxxxgvx.cpp
10+
811

912
${HSOLVER_PATH}/kernels/dngvd_op.cpp
1013
${HSOLVER_PATH}/kernels/math_kernel_op.cpp

python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class PyDiagoDavSubspace
139139
max_iter,
140140
need_subspace,
141141
comm_info,
142-
PARAM.inp.diag_subspace_method,
142+
PARAM.inp.diag_subspace,
143143
PARAM.inp.nb2d
144144
);
145145

source/module_base/scalapack_connector.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,19 @@ extern "C"
8080
const double* vl, const double* vu, const int* il, const int* iu,
8181
const double* abstol, int* m, int* nz, double* w, const double*orfac, double* Z, const int* iz, const int* jz, const int*descz,
8282
double* work, int* lwork, int*iwork, int*liwork, int* ifail, int*iclustr, double*gap, int* info);
83+
8384
void pzhegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
8485
const int* n, std::complex<double>* A, const int* ia, const int* ja, const int*desca, std::complex<double>* B, const int* ib, const int* jb, const int*descb,
8586
const double* vl, const double* vu, const int* il, const int* iu,
8687
const double* abstol, int* m, int* nz, double* w, const double*orfac, std::complex<double>* Z, const int* iz, const int* jz, const int*descz,
8788
std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int*iwork, int*liwork, int* ifail, int*iclustr, double*gap, int* info);
89+
8890
void pssygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
8991
const int* n, float* A, const int* ia, const int* ja, const int*desca, float* B, const int* ib, const int* jb, const int*descb,
9092
const float* vl, const float* vu, const int* il, const int* iu,
9193
const float* abstol, int* m, int* nz, float* w, const float*orfac, float* Z, const int* iz, const int* jz, const int*descz,
9294
float* work, int* lwork, int*iwork, int*liwork, int* ifail, int*iclustr, float*gap, int* info);
95+
9396
void pchegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
9497
const int* n, std::complex<float>* A, const int* ia, const int* ja, const int*desca, std::complex<float>* B, const int* ib, const int* jb, const int*descb,
9598
const float* vl, const float* vu, const int* il, const int* iu,

source/module_hsolver/diag_hs_para.cpp

Lines changed: 126 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,88 @@
1-
#include <iostream>
21
#include "module_hsolver/diag_hs_para.h"
2+
3+
#include "module_base/scalapack_connector.h"
34
#include "module_basis/module_ao/parallel_2d.h"
45
#include "module_hsolver/diago_pxxxgvx.h"
5-
#include "module_base/scalapack_connector.h"
66
#include "module_hsolver/genelpa/elpa_solver.h"
77

8+
#include <iostream>
9+
810
namespace hsolver
911
{
1012

11-
#ifdef __ELPA
12-
void elpa_diag(MPI_Comm comm,
13-
const int nband,
14-
std::complex<double>* h_local,
15-
std::complex<double>* s_local,
16-
double* ekb,
17-
std::complex<double>* wfc_2d,
18-
Parallel_2D& para2d_local)
19-
{
20-
int DecomposedState = 0;
21-
ELPA_Solver es(false,
22-
comm,
23-
nband,
24-
para2d_local.get_row_size(),
25-
para2d_local.get_col_size(),
26-
para2d_local.desc);
27-
es.generalized_eigenvector(h_local, s_local, DecomposedState, ekb, wfc_2d);
28-
es.exit();
29-
}
13+
#ifdef __ELPA
14+
void elpa_diag(MPI_Comm comm,
15+
const int nband,
16+
std::complex<double>* h_local,
17+
std::complex<double>* s_local,
18+
double* ekb,
19+
std::complex<double>* wfc_2d,
20+
Parallel_2D& para2d_local)
21+
{
22+
int DecomposedState = 0;
23+
ELPA_Solver es(false, comm, nband, para2d_local.get_row_size(), para2d_local.get_col_size(), para2d_local.desc);
24+
es.generalized_eigenvector(h_local, s_local, DecomposedState, ekb, wfc_2d);
25+
es.exit();
26+
}
3027

31-
void elpa_diag(MPI_Comm comm,
32-
const int nband,
33-
double* h_local,
34-
double* s_local,
35-
double* ekb,
36-
double* wfc_2d,
37-
Parallel_2D& para2d_local)
38-
{
39-
int DecomposedState = 0;
40-
ELPA_Solver es(true,
41-
comm,
42-
nband,
43-
para2d_local.get_row_size(),
44-
para2d_local.get_col_size(),
45-
para2d_local.desc);
46-
es.generalized_eigenvector(h_local, s_local, DecomposedState, ekb, wfc_2d);
47-
es.exit();
48-
}
28+
void elpa_diag(MPI_Comm comm,
29+
const int nband,
30+
double* h_local,
31+
double* s_local,
32+
double* ekb,
33+
double* wfc_2d,
34+
Parallel_2D& para2d_local)
35+
{
36+
int DecomposedState = 0;
37+
ELPA_Solver es(true, comm, nband, para2d_local.get_row_size(), para2d_local.get_col_size(), para2d_local.desc);
38+
es.generalized_eigenvector(h_local, s_local, DecomposedState, ekb, wfc_2d);
39+
es.exit();
40+
}
4941

50-
void elpa_diag(MPI_Comm comm,
51-
const int nband,
52-
std::complex<float>* h_local,
53-
std::complex<float>* s_local,
54-
float* ekb,
55-
std::complex<float>* wfc_2d,
56-
Parallel_2D& para2d_local)
57-
{
58-
std::cout << "Error: ELPA do not support single precision. " << std::endl;
59-
exit(1);
60-
}
42+
void elpa_diag(MPI_Comm comm,
43+
const int nband,
44+
std::complex<float>* h_local,
45+
std::complex<float>* s_local,
46+
float* ekb,
47+
std::complex<float>* wfc_2d,
48+
Parallel_2D& para2d_local)
49+
{
50+
std::cout << "Error: ELPA do not support single precision. " << std::endl;
51+
exit(1);
52+
}
6153

62-
void elpa_diag(MPI_Comm comm,
63-
const int nband,
64-
float* h_local,
65-
float* s_local,
66-
float* ekb,
67-
float* wfc_2d,
68-
Parallel_2D& para2d_local)
69-
{
70-
std::cout << "Error: ELPA do not support single precision. " << std::endl;
71-
exit(1);
72-
}
54+
void elpa_diag(MPI_Comm comm,
55+
const int nband,
56+
float* h_local,
57+
float* s_local,
58+
float* ekb,
59+
float* wfc_2d,
60+
Parallel_2D& para2d_local)
61+
{
62+
std::cout << "Error: ELPA do not support single precision. " << std::endl;
63+
exit(1);
64+
}
7365

7466
#endif
7567

76-
7768
#ifdef __MPI
7869

7970
template <typename T>
80-
void Diago_HS_para(
81-
T* h,
82-
T* s,
83-
const int lda,
84-
const int nband,
85-
typename GetTypeReal<T>::type *const ekb,
86-
T *const wfc,
87-
const MPI_Comm& comm,
88-
const int diag_subspace_method,
89-
const int block_size)
71+
void Diago_HS_para(T* h,
72+
T* s,
73+
const int lda,
74+
const int nband,
75+
typename GetTypeReal<T>::type* const ekb,
76+
T* const wfc,
77+
const MPI_Comm& comm,
78+
const int diag_subspace,
79+
const int block_size)
9080
{
91-
int myrank;
81+
int myrank = 0;
9282
MPI_Comm_rank(comm, &myrank);
9383
Parallel_2D para2d_global;
9484
Parallel_2D para2d_local;
95-
para2d_global.init(lda,lda,lda,comm);
85+
para2d_global.init(lda, lda, lda, comm);
9686

9787
int max_nb = block_size;
9888
if (block_size == 0)
@@ -113,88 +103,99 @@ void Diago_HS_para(
113103
}
114104

115105
// for genelpa, if the block size is too large that some cores have no data, then it will cause error.
116-
if (diag_subspace_method == 1)
106+
if (diag_subspace == 1)
117107
{
118108
if (max_nb * (std::max(para2d_global.dim0, para2d_global.dim1) - 1) >= lda)
119109
{
120110
max_nb = lda / std::max(para2d_global.dim0, para2d_global.dim1);
121111
}
122112
}
123-
124-
para2d_local.init(lda,lda,max_nb,comm);
113+
114+
para2d_local.init(lda, lda, max_nb, comm);
125115
std::vector<T> h_local(para2d_local.get_col_size() * para2d_local.get_row_size());
126116
std::vector<T> s_local(para2d_local.get_col_size() * para2d_local.get_row_size());
127117
std::vector<T> wfc_2d(para2d_local.get_col_size() * para2d_local.get_row_size());
128-
118+
129119
// distribute h and s to 2D
130-
Cpxgemr2d(lda,lda,h,1,1,para2d_global.desc,h_local.data(),1,1,para2d_local.desc,para2d_local.blacs_ctxt);
131-
Cpxgemr2d(lda,lda,s,1,1,para2d_global.desc,s_local.data(),1,1,para2d_local.desc,para2d_local.blacs_ctxt);
120+
Cpxgemr2d(lda, lda, h, 1, 1, para2d_global.desc, h_local.data(), 1, 1, para2d_local.desc, para2d_local.blacs_ctxt);
121+
Cpxgemr2d(lda, lda, s, 1, 1, para2d_global.desc, s_local.data(), 1, 1, para2d_local.desc, para2d_local.blacs_ctxt);
132122

133-
if (diag_subspace_method == 1)
123+
if (diag_subspace == 1)
134124
{
135-
#ifdef __ELPA
125+
#ifdef __ELPA
136126
elpa_diag(comm, nband, h_local.data(), s_local.data(), ekb, wfc_2d.data(), para2d_local);
137127
#else
138-
std::cout << "ERROR: try to use ELPA to solve the generalized eigenvalue problem, but ELPA is not compiled. " << std::endl;
128+
std::cout << "ERROR: try to use ELPA to solve the generalized eigenvalue problem, but ELPA is not compiled. "
129+
<< std::endl;
139130
exit(1);
140-
#endif
131+
#endif
141132
}
142-
else if (diag_subspace_method == 2)
133+
else if (diag_subspace == 2)
143134
{
144-
hsolver::pxxxgvx_diag(para2d_local.desc, para2d_local.get_row_size(), para2d_local.get_col_size(),nband, h_local.data(), s_local.data(), ekb, wfc_2d.data());
135+
hsolver::pxxxgvx_diag(para2d_local.desc,
136+
para2d_local.get_row_size(),
137+
para2d_local.get_col_size(),
138+
nband,
139+
h_local.data(),
140+
s_local.data(),
141+
ekb,
142+
wfc_2d.data());
145143
}
146-
else{
147-
std::cout << "Error: parallel diagonalization method is not supported. " << "diag_subspace_method = " << diag_subspace_method << std::endl;
144+
else
145+
{
146+
std::cout << "Error: parallel diagonalization method is not supported. " << "diag_subspace = " << diag_subspace
147+
<< std::endl;
148148
exit(1);
149149
}
150150

151151
// gather wfc
152-
Cpxgemr2d(lda,lda,wfc_2d.data(),1,1,para2d_local.desc,wfc,1,1,para2d_global.desc,para2d_local.blacs_ctxt);
152+
Cpxgemr2d(lda, lda, wfc_2d.data(), 1, 1, para2d_local.desc, wfc, 1, 1, para2d_global.desc, para2d_local.blacs_ctxt);
153153

154154
// free the context
155155
Cblacs_gridexit(para2d_local.blacs_ctxt);
156156
Cblacs_gridexit(para2d_global.blacs_ctxt);
157157
}
158158

159159
// template instantiation
160-
template void Diago_HS_para<double>(double* h,
161-
double* s,
162-
const int lda,
163-
const int nband,
164-
typename GetTypeReal<double>::type *const ekb,
165-
double *const wfc,
166-
const MPI_Comm& comm,
167-
const int diag_subspace_method,
168-
const int block_size);
169-
template void Diago_HS_para<std::complex<double>>(std::complex<double>* h,
170-
std::complex<double>* s,
171-
const int lda,
172-
const int nband,
173-
typename GetTypeReal<std::complex<double>>::type *const ekb,
174-
std::complex<double> *const wfc,
175-
const MPI_Comm& comm,
176-
const int diag_subspace_method,
177-
const int block_size);
178-
template void Diago_HS_para<float>(float* h,
179-
float* s,
180-
const int lda,
181-
const int nband,
182-
typename GetTypeReal<float>::type *const ekb,
183-
float *const wfc,
184-
const MPI_Comm& comm,
185-
const int diag_subspace_method,
160+
template void Diago_HS_para<double>(double* h,
161+
double* s,
162+
const int lda,
163+
const int nband,
164+
typename GetTypeReal<double>::type* const ekb,
165+
double* const wfc,
166+
const MPI_Comm& comm,
167+
const int diag_subspace,
186168
const int block_size);
187-
template void Diago_HS_para<std::complex<float>>(std::complex<float>* h,
188-
std::complex<float>* s,
189-
const int lda,
190-
const int nband,
191-
typename GetTypeReal<std::complex<float>>::type *const ekb,
192-
std::complex<float> *const wfc,
193-
const MPI_Comm& comm,
194-
const int diag_subspace_method,
195-
const int block_size);
196-
197169

170+
template void Diago_HS_para<std::complex<double>>(std::complex<double>* h,
171+
std::complex<double>* s,
172+
const int lda,
173+
const int nband,
174+
typename GetTypeReal<std::complex<double>>::type* const ekb,
175+
std::complex<double>* const wfc,
176+
const MPI_Comm& comm,
177+
const int diag_subspace,
178+
const int block_size);
179+
180+
template void Diago_HS_para<float>(float* h,
181+
float* s,
182+
const int lda,
183+
const int nband,
184+
typename GetTypeReal<float>::type* const ekb,
185+
float* const wfc,
186+
const MPI_Comm& comm,
187+
const int diag_subspace,
188+
const int block_size);
189+
190+
template void Diago_HS_para<std::complex<float>>(std::complex<float>* h,
191+
std::complex<float>* s,
192+
const int lda,
193+
const int nband,
194+
typename GetTypeReal<std::complex<float>>::type* const ekb,
195+
std::complex<float>* const wfc,
196+
const MPI_Comm& comm,
197+
const int diag_subspace,
198+
const int block_size);
198199

199200
#endif
200201

0 commit comments

Comments
 (0)