Skip to content

Commit 9fa76dc

Browse files
committed
change the signature of cg
1 parent c5da614 commit 9fa76dc

File tree

4 files changed

+129
-41
lines changed

4 files changed

+129
-41
lines changed

python/pyabacus/examples/diago_matrix.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,20 @@ def load_mat(mat_file):
99

1010
return h_mat, nbasis, nband
1111

12-
def calc_eig_pyabacus(mat_file, method):
12+
def gen_dense_mat(dim):
13+
# generate a random symmetric and positive definite matrix
14+
h_mat = np.random.rand(dim, dim)
15+
h_mat = h_mat + h_mat.T
16+
h_mat = h_mat + dim * np.eye(dim)
17+
18+
return h_mat
19+
20+
def calc_eig_dav(mat_file, method):
1321
algo = {
1422
'dav_subspace': hsolver.dav_subspace,
15-
'davidson': hsolver.davidson,
16-
'cg': hsolver.cg
23+
'davidson': hsolver.davidson
1724
}
1825

19-
if method is not 'cg':
20-
ndim = 8
21-
else:
22-
ndim = 30
23-
2426
h_mat, nbasis, nband = load_mat(mat_file)
2527

2628
v0 = np.random.rand(nbasis, nband)
@@ -37,31 +39,70 @@ def mm_op(x):
3739
nbasis,
3840
nband,
3941
precond,
40-
ndim,
41-
1e-8 # tol
42+
dav_ndim=8,
43+
tol=1e-8
4244
)
4345

4446
print(f'eigenvalues calculated by pyabacus-{method} is: \n', e)
4547

4648
return e
4749

48-
def calc_eig_scipy(mat_file):
50+
def calc_eig_cg(h_mat, num_eigs):
51+
dim = h_mat.shape[0]
52+
v0 = np.random.rand(dim, num_eigs)
53+
diag_elem = h_mat.diagonal()
54+
diag_elem = np.where(np.abs(diag_elem) < 1e-8, 1e-8, diag_elem)
55+
precond = 1.0 / np.abs(diag_elem)
56+
57+
def mm_op(x):
58+
return h_mat.dot(x)
59+
60+
e, _ = hsolver.cg(
61+
mm_op,
62+
v0,
63+
dim,
64+
num_eigs,
65+
precond,
66+
tol=1e-8
67+
)
68+
69+
print('eigenvalues calculated by pyabacus-cg is: \n', e)
70+
71+
return e
72+
73+
def calc_eigsh(mat_file):
4974
h_mat, _, nband = load_mat(mat_file)
5075
e, _ = scipy.sparse.linalg.eigsh(h_mat, k=nband, which='SA', maxiter=1000)
5176
e = np.sort(e)
5277
print('eigenvalues calculated by scipy is: \n', e)
5378

5479
return e
5580

81+
def calc_eigh(h_mat, num_eigs):
82+
e, _ = scipy.linalg.eigh(h_mat)
83+
e = np.sort(e)
84+
print('eigenvalues calculated by scipy is: \n', e[:num_eigs])
85+
86+
return e
87+
5688
if __name__ == '__main__':
5789
mat_file = './Si2.mat'
58-
method = ['dav_subspace', 'davidson', 'cg']
90+
method = ['dav_subspace', 'davidson']
5991

6092
for m in method:
6193
print(f'\n====== Calculating eigenvalues using {m} method... ======')
62-
e_pyabacus = calc_eig_pyabacus(mat_file, m)
63-
e_scipy = calc_eig_scipy(mat_file)
94+
e_pyabacus = calc_eig_dav(mat_file, m)
95+
e_scipy = calc_eigsh(mat_file)
6496

6597
print('eigenvalues difference: \n', e_pyabacus - e_scipy)
98+
99+
print("\n====== davidson and dav_subspace method Done! ======")
100+
print("\n====== CG method ======")
101+
102+
h_mat = gen_dense_mat(100)
103+
num_eigs = 8
104+
e_cg = calc_eig_cg(h_mat, num_eigs)
105+
e_scipy = calc_eigh(h_mat, num_eigs)
66106

67-
107+
print('eigenvalues difference: \n', e_cg - e_scipy[:num_eigs])
108+
print("\n====== CG method Done! ======")

python/pyabacus/src/hsolver/py_diago_cg.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ class PyDiagoCG
131131
const int nvec = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
132132
const int ld_psi = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
133133

134+
// Note: numpy's py::array_t is row-major, and
135+
// our tensor-array is row-major
134136
py::array_t<std::complex<double>> psi({ld_psi, nvec});
135137
py::buffer_info psi_buf = psi.request();
136138
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
@@ -169,7 +171,7 @@ class PyDiagoCG
169171
nproc_in_pool
170172
);
171173

172-
return cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
174+
cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
173175
}
174176

175177
private:

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def dav_subspace(
2828
init_v: NDArray[np.complex128],
2929
dim: int,
3030
num_eigs: int,
31-
pre_condition: NDArray[np.float64],
31+
precondition: NDArray[np.float64],
3232
dav_ndim: int = 2,
3333
tol: float = 1e-2,
3434
max_iter: int = 1000,
@@ -50,7 +50,7 @@ def dav_subspace(
5050
The number of basis, i.e. the number of rows/columns in the matrix.
5151
num_eigs : int
5252
The number of bands to calculate, i.e. the number of eigenvalues to calculate.
53-
pre_condition : NDArray[np.float64]
53+
precondition : NDArray[np.float64]
5454
The preconditioner.
5555
dav_ndim : int, optional
5656
The number of vectors in the subspace, by default 2.
@@ -94,7 +94,7 @@ def dav_subspace(
9494

9595
_ = _diago_obj_dav_subspace.diag(
9696
mvv_op,
97-
pre_condition,
97+
precondition,
9898
dav_ndim,
9999
tol,
100100
max_iter,
@@ -114,7 +114,7 @@ def davidson(
114114
init_v: NDArray[np.complex128],
115115
dim: int,
116116
num_eigs: int,
117-
pre_condition: NDArray[np.float64],
117+
precondition: NDArray[np.float64],
118118
dav_ndim: int = 2,
119119
tol: float = 1e-2,
120120
max_iter: int = 1000,
@@ -135,7 +135,7 @@ def davidson(
135135
The number of basis, i.e. the number of rows/columns in the matrix.
136136
num_eigs : int
137137
The number of bands to calculate, i.e. the number of eigenvalues to calculate.
138-
pre_condition : NDArray[np.float64]
138+
precondition : NDArray[np.float64]
139139
The preconditioner.
140140
dav_ndim : int, optional
141141
The number of vectors in the subspace, by default 2.
@@ -167,7 +167,7 @@ def davidson(
167167

168168
_ = _diago_obj_david.diag(
169169
mvv_op,
170-
pre_condition,
170+
precondition,
171171
dav_ndim,
172172
tol,
173173
max_iter,
@@ -185,12 +185,50 @@ def cg(
185185
init_v: NDArray[np.complex128],
186186
dim: int,
187187
num_eigs: int,
188-
pre_condition: NDArray[np.float64],
189-
diag_ndim: int = 2,
188+
precondition: NDArray[np.float64],
190189
tol: float = 1e-2,
190+
max_iter: int = 1000,
191191
need_subspace: bool = False,
192-
scf_type: bool = False
192+
scf_type: bool = False,
193+
nproc_in_pool: int = 1
193194
) -> Tuple[NDArray[np.float64], NDArray[np.complex128]]:
195+
""" A function to diagonalize a matrix using the Conjugate Gradient method.
196+
197+
Parameters
198+
----------
199+
mvv_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
200+
The operator to be diagonalized, which is a function that takes a set of
201+
vectors X = [x1, ..., xN] as input and returns a matrix(vector block)
202+
mvv_op(X) = H * X ([Hx1, ..., HxN]) as output.
203+
init_v : NDArray[np.complex128]
204+
The initial guess for the eigenvectors.
205+
dim : int
206+
The number of basis, i.e. the number of rows/columns in the matrix.
207+
num_eigs : int
208+
The number of bands to calculate, i.e. the number of eigenvalues to calculate.
209+
precondition : NDArray[np.float64]
210+
The preconditioner.
211+
max_iter : int, optional
212+
The maximum number of iterations, by default 1000.
213+
tol : float, optional
214+
The tolerance for the convergence, by default 1e-2.
215+
need_subspace : bool, optional
216+
Whether to use subspace function, by default False.
217+
scf_type : bool, optional
218+
Indicates whether the calculation is a self-consistent field (SCF) calculation.
219+
If True, the initial precision of eigenvalue calculation can be coarse.
220+
If False, it indicates a non-self-consistent field (non-SCF) calculation,
221+
where high precision in eigenvalue calculation is required from the start.
222+
nproc_in_pool : int, optional
223+
The number of processes in the pool, by default 1.
224+
225+
Returns
226+
-------
227+
e : NDArray[np.float64]
228+
The eigenvalues.
229+
v : NDArray[np.complex128]
230+
The eigenvectors corresponding to the eigenvalues.
231+
"""
194232
if not callable(mvv_op):
195233
raise TypeError("mvv_op must be a callable object.")
196234

@@ -204,17 +242,18 @@ def cg(
204242
_diago_obj_cg.set_psi(init_v)
205243
_diago_obj_cg.init_eig()
206244

207-
_diago_obj_cg.set_prec(pre_condition)
245+
_diago_obj_cg.set_prec(precondition)
208246

209-
_ = _diago_obj_cg.diag(
247+
_diago_obj_cg.diag(
210248
mvv_op,
211-
diag_ndim,
249+
max_iter,
212250
tol,
213251
need_subspace,
214252
scf_type,
253+
nproc_in_pool
215254
)
216255

217-
e = _diago_obj_cg.get_eigenvalue()
256+
e = _diago_obj_cg.get_eig()
218257
v = _diago_obj_cg.get_psi()
219258

220259
return e, v

python/pyabacus/tests/test_hsolver.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
import scipy
77

88
def diag_pyabacus(h_sparse, nband, method):
9-
algo = {
9+
dav = {
1010
'dav_subspace': hsolver.dav_subspace,
1111
'davidson': hsolver.davidson
1212
}
13+
cg = {
14+
'cg': hsolver.cg
15+
}
1316
def mm_op(x):
1417
return h_sparse.dot(x)
1518

@@ -21,16 +24,16 @@ def mm_op(x):
2124
diag_elem = np.where(np.abs(diag_elem) < 1e-8, 1e-8, diag_elem)
2225
precond = 1.0 / np.abs(diag_elem)
2326

24-
e, _ = algo[method](
25-
mm_op,
26-
v0,
27-
nbasis,
28-
nband,
29-
precond,
30-
dav_ndim=8,
31-
tol=1e-12,
32-
max_iter=5000
33-
)
27+
if method in dav:
28+
algo = dav[method]
29+
args = (mm_op, v0, nbasis, nband, precond, 8, 1e-12, 5000)
30+
elif method in cg:
31+
algo = cg[method]
32+
args = (mm_op, v0, nbasis, nband, precond, 1e-12, 5000)
33+
else:
34+
raise ValueError(f"Method {method} not available")
35+
36+
e, _ = algo(*args)
3437

3538
return e
3639

@@ -40,7 +43,8 @@ def diag_eigsh(h_sparse, nband):
4043

4144
@pytest.mark.parametrize("method", [
4245
('dav_subspace'),
43-
('davidson')
46+
('davidson'),
47+
('cg')
4448
])
4549
def test_random_matrix_diag(method):
4650
np.random.seed(12)
@@ -55,8 +59,10 @@ def test_random_matrix_diag(method):
5559
@pytest.mark.parametrize("file_name, nband, atol, method", [
5660
('./test_diag/Si2.mat', 16, 1e-8, 'dav_subspace'),
5761
('./test_diag/Si2.mat', 16, 1e-8, 'davidson'),
62+
('./test_diag/Si2.mat', 16, 1e-8, 'cg'),
5863
('./test_diag/Na5.mat', 16, 1e-8, 'dav_subspace'),
5964
('./test_diag/Na5.mat', 16, 1e-8, 'davidson'),
65+
('./test_diag/Na5.mat', 16, 1e-8, 'cg'),
6066
])
6167
def test_diag(file_name, nband, atol, method):
6268
h_sparse = scipy.io.loadmat(file_name)['Problem']['A'][0, 0]

0 commit comments

Comments
 (0)