Skip to content

Commit bd460f0

Browse files
committed
add cg diagnolization method
1 parent 9fa76dc commit bd460f0

File tree

2 files changed

+34
-57
lines changed

2 files changed

+34
-57
lines changed

python/pyabacus/examples/diago_matrix.py

Lines changed: 21 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@ def gen_dense_mat(dim):
1717

1818
return h_mat
1919

20-
def calc_eig_dav(mat_file, method):
21-
algo = {
20+
def calc_eig_pyabacus(mat_file, method):
21+
dav = {
2222
'dav_subspace': hsolver.dav_subspace,
2323
'davidson': hsolver.davidson
2424
}
25+
cg = {
26+
'cg': hsolver.cg
27+
}
2528

2629
h_mat, nbasis, nband = load_mat(mat_file)
2730

@@ -30,43 +33,23 @@ def calc_eig_dav(mat_file, method):
3033
diag_elem = np.where(np.abs(diag_elem) < 1e-8, 1e-8, diag_elem)
3134
precond = 1.0 / np.abs(diag_elem)
3235

33-
def mm_op(x):
36+
def mvv_op(x):
3437
return h_mat.dot(x)
3538

36-
e, _ = algo[method](
37-
mm_op,
38-
v0,
39-
nbasis,
40-
nband,
41-
precond,
42-
dav_ndim=8,
43-
tol=1e-8
44-
)
45-
46-
print(f'eigenvalues calculated by pyabacus-{method} is: \n', e)
39+
if method in dav:
40+
algo = dav[method]
41+
# args: mvvop, init_v, dim, num_eigs, precondition, dav_ndim, tol, max_iter
42+
args = (mvv_op, v0, nbasis, nband, precond, 8, 1e-12, 5000)
43+
elif method in cg:
44+
algo = cg[method]
45+
# args: mvvop, init_v, dim, num_eigs, precondition, tol, max_iter
46+
args = (mvv_op, v0, nbasis, nband, precond, 1e-12, 5000)
47+
else:
48+
raise ValueError(f"Method {method} not available")
4749

48-
return e
50+
e, _ = algo(*args)
4951

50-
def calc_eig_cg(h_mat, num_eigs):
51-
dim = h_mat.shape[0]
52-
v0 = np.random.rand(dim, num_eigs)
53-
diag_elem = h_mat.diagonal()
54-
diag_elem = np.where(np.abs(diag_elem) < 1e-8, 1e-8, diag_elem)
55-
precond = 1.0 / np.abs(diag_elem)
56-
57-
def mm_op(x):
58-
return h_mat.dot(x)
59-
60-
e, _ = hsolver.cg(
61-
mm_op,
62-
v0,
63-
dim,
64-
num_eigs,
65-
precond,
66-
tol=1e-8
67-
)
68-
69-
print('eigenvalues calculated by pyabacus-cg is: \n', e)
52+
print(f'eigenvalues calculated by pyabacus-{method} is: \n', e)
7053

7154
return e
7255

@@ -78,31 +61,14 @@ def calc_eigsh(mat_file):
7861

7962
return e
8063

81-
def calc_eigh(h_mat, num_eigs):
82-
e, _ = scipy.linalg.eigh(h_mat)
83-
e = np.sort(e)
84-
print('eigenvalues calculated by scipy is: \n', e[:num_eigs])
85-
86-
return e
87-
8864
if __name__ == '__main__':
8965
mat_file = './Si2.mat'
90-
method = ['dav_subspace', 'davidson']
66+
method = ['dav_subspace', 'davidson', 'cg']
9167

9268
for m in method:
9369
print(f'\n====== Calculating eigenvalues using {m} method... ======')
94-
e_pyabacus = calc_eig_dav(mat_file, m)
70+
e_pyabacus = calc_eig_pyabacus(mat_file, m)
9571
e_scipy = calc_eigsh(mat_file)
9672

9773
print('eigenvalues difference: \n', e_pyabacus - e_scipy)
98-
99-
print("\n====== davidson and dav_subspace method Done! ======")
100-
print("\n====== CG method ======")
101-
102-
h_mat = gen_dense_mat(100)
103-
num_eigs = 8
104-
e_cg = calc_eig_cg(h_mat, num_eigs)
105-
e_scipy = calc_eigh(h_mat, num_eigs)
106-
107-
print('eigenvalues difference: \n', e_cg - e_scipy[:num_eigs])
108-
print("\n====== CG method Done! ======")
74+

python/pyabacus/src/hsolver/py_hsolver.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,21 @@ void bind_hsolver(py::module& m)
160160
161161
Parameters
162162
----------
163-
TO BE FILLED
163+
mm_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
164+
The operator to be diagonalized, which is a function that takes a matrix as input
165+
and returns a matrix mv_op(X) = H * X as output.
166+
max_iter : int
167+
The maximum number of iterations.
168+
tol : double
169+
The tolerance for the convergence.
170+
need_subspace : bool
171+
Whether to use the subspace function.
172+
scf_type : bool
173+
Whether to use the SCF type, which is used to determine the
174+
convergence criterion.
164175
)pbdoc",
165176
"mm_op"_a,
166-
"diag_ndim"_a,
177+
"max_iter"_a,
167178
"tol"_a,
168179
"need_subspace"_a,
169180
"scf_type"_a,

0 commit comments

Comments
 (0)