Skip to content

Commit 5f9a1c6

Browse files
authored
Feature: Add CG algorithm to pyabacus.hsolver (#5398)
* fix some typos in `_hsolver.py` * fix some bugs caused by #5134 * Refactor hsolver module and remove unused code * refactor the structure of pythonization source code * fix some bug * Refactor __getattr__ function in __init__.py to handle attribute errors * fix some bugs * Add CONTRIBUTING.md to facilitate contributing to pyabacus project * fix typos * Update CONTRIBUTING.md * Update CONTRIBUTING.md * update README.md and CONTRIBUTING.md * update README.md * update CONTRIBUTING.md * update CONTRIBUTING.md * fix a bug caused by tuple in python3.8 * update * add basic framework for diagoCG * add cd to hsolver * change the signature of cg * add cg diagnolization method * add __all__ * remove unused code in example
1 parent 628bdc6 commit 5f9a1c6

File tree

7 files changed

+368
-38
lines changed

7 files changed

+368
-38
lines changed

python/pyabacus/examples/diago_matrix.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ def load_mat(mat_file):
1010
return h_mat, nbasis, nband
1111

1212
def calc_eig_pyabacus(mat_file, method):
13-
algo = {
13+
dav = {
1414
'dav_subspace': hsolver.dav_subspace,
1515
'davidson': hsolver.davidson
1616
}
17+
cg = {
18+
'cg': hsolver.cg
19+
}
1720

1821
h_mat, nbasis, nband = load_mat(mat_file)
1922

@@ -22,25 +25,27 @@ def calc_eig_pyabacus(mat_file, method):
2225
diag_elem = np.where(np.abs(diag_elem) < 1e-8, 1e-8, diag_elem)
2326
precond = 1.0 / np.abs(diag_elem)
2427

25-
def mm_op(x):
28+
def mvv_op(x):
2629
return h_mat.dot(x)
2730

28-
e, _ = algo[method](
29-
mm_op,
30-
v0,
31-
nbasis,
32-
nband,
33-
precond,
34-
dav_ndim=8,
35-
tol=1e-8,
36-
max_iter=1000
37-
)
31+
if method in dav:
32+
algo = dav[method]
33+
# args: mvvop, init_v, dim, num_eigs, precondition, dav_ndim, tol, max_iter
34+
args = (mvv_op, v0, nbasis, nband, precond, 8, 1e-12, 5000)
35+
elif method in cg:
36+
algo = cg[method]
37+
# args: mvvop, init_v, dim, num_eigs, precondition, tol, max_iter
38+
args = (mvv_op, v0, nbasis, nband, precond, 1e-12, 5000)
39+
else:
40+
raise ValueError(f"Method {method} not available")
41+
42+
e, _ = algo(*args)
3843

3944
print(f'eigenvalues calculated by pyabacus-{method} is: \n', e)
4045

4146
return e
4247

43-
def calc_eig_scipy(mat_file):
48+
def calc_eigsh(mat_file):
4449
h_mat, _, nband = load_mat(mat_file)
4550
e, _ = scipy.sparse.linalg.eigsh(h_mat, k=nband, which='SA', maxiter=1000)
4651
e = np.sort(e)
@@ -50,13 +55,12 @@ def calc_eig_scipy(mat_file):
5055

5156
if __name__ == '__main__':
5257
mat_file = './Si2.mat'
53-
method = ['dav_subspace', 'davidson']
58+
method = ['dav_subspace', 'davidson', 'cg']
5459

5560
for m in method:
5661
print(f'\n====== Calculating eigenvalues using {m} method... ======')
5762
e_pyabacus = calc_eig_pyabacus(mat_file, m)
58-
e_scipy = calc_eig_scipy(mat_file)
63+
e_scipy = calc_eigsh(mat_file)
5964

6065
print('eigenvalues difference: \n', e_pyabacus - e_scipy)
61-
62-
66+

python/pyabacus/src/hsolver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
list(APPEND _diago
33
${HSOLVER_PATH}/diago_dav_subspace.cpp
44
${HSOLVER_PATH}/diago_david.cpp
5+
${HSOLVER_PATH}/diago_cg.cpp
56
${HSOLVER_PATH}/diag_const_nums.cpp
67
${HSOLVER_PATH}/diago_iter_assist.cpp
78

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
#ifndef PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
2+
#define PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
3+
4+
#include <complex>
5+
#include <functional>
6+
7+
#include <pybind11/pybind11.h>
8+
#include <pybind11/complex.h>
9+
#include <pybind11/functional.h>
10+
#include <pybind11/numpy.h>
11+
#include <pybind11/stl.h>
12+
13+
#include <ATen/core/tensor.h>
14+
#include <ATen/core/tensor_map.h>
15+
#include <ATen/core/tensor_types.h>
16+
17+
#include "module_hsolver/diago_cg.h"
18+
#include "module_base/module_device/memory_op.h"
19+
20+
namespace py = pybind11;
21+
22+
namespace py_hsolver
23+
{
24+
25+
class PyDiagoCG
26+
{
27+
public:
28+
PyDiagoCG(int dim, int num_eigs) : dim{dim}, num_eigs{num_eigs} { }
29+
PyDiagoCG(const PyDiagoCG&) = delete;
30+
PyDiagoCG& operator=(const PyDiagoCG&) = delete;
31+
PyDiagoCG(PyDiagoCG&& other)
32+
{
33+
psi = other.psi;
34+
other.psi = nullptr;
35+
36+
eig = other.eig;
37+
other.eig = nullptr;
38+
}
39+
40+
~PyDiagoCG()
41+
{
42+
if (psi != nullptr)
43+
{
44+
delete psi;
45+
psi = nullptr;
46+
}
47+
48+
if (eig != nullptr)
49+
{
50+
delete eig;
51+
eig = nullptr;
52+
}
53+
}
54+
55+
void init_eig()
56+
{
57+
eig = new ct::Tensor(ct::DataType::DT_DOUBLE, {num_eigs});
58+
eig->zero();
59+
}
60+
61+
py::array_t<double> get_eig()
62+
{
63+
py::array_t<double> eig_out(eig->NumElements());
64+
py::buffer_info eig_buf = eig_out.request();
65+
double* eig_out_ptr = static_cast<double*>(eig_buf.ptr);
66+
67+
if (eig == nullptr) {
68+
throw std::runtime_error("eig is not initialized");
69+
}
70+
double* eig_ptr = eig->data<double>();
71+
72+
std::copy(eig_ptr, eig_ptr + eig->NumElements(), eig_out_ptr);
73+
return eig_out;
74+
}
75+
76+
void set_psi(py::array_t<std::complex<double>> psi_in)
77+
{
78+
py::buffer_info psi_buf = psi_in.request();
79+
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
80+
81+
psi = new ct::TensorMap(
82+
psi_ptr,
83+
ct::DataType::DT_COMPLEX_DOUBLE,
84+
ct::DeviceType::CpuDevice,
85+
ct::TensorShape({num_eigs, dim})
86+
);
87+
}
88+
89+
py::array_t<std::complex<double>> get_psi()
90+
{
91+
py::array_t<std::complex<double>> psi_out({num_eigs, dim});
92+
py::buffer_info psi_buf = psi_out.request();
93+
std::complex<double>* psi_out_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
94+
95+
if (psi == nullptr) {
96+
throw std::runtime_error("psi is not initialized");
97+
}
98+
std::complex<double>* psi_ptr = psi->data<std::complex<double>>();
99+
100+
std::copy(psi_ptr, psi_ptr + psi->NumElements(), psi_out_ptr);
101+
return psi_out;
102+
}
103+
104+
void set_prec(py::array_t<double> prec_in)
105+
{
106+
py::buffer_info prec_buf = prec_in.request();
107+
double* prec_ptr = static_cast<double*>(prec_buf.ptr);
108+
109+
prec = new ct::TensorMap(
110+
prec_ptr,
111+
ct::DataType::DT_DOUBLE,
112+
ct::DeviceType::CpuDevice,
113+
ct::TensorShape({dim})
114+
);
115+
}
116+
117+
void diag(
118+
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
119+
int diag_ndim,
120+
double tol,
121+
bool need_subspace,
122+
bool scf_type,
123+
int nproc_in_pool = 1
124+
) {
125+
const std::string basis_type = "pw";
126+
const std::string calculation = scf_type ? "scf" : "nscf";
127+
128+
auto hpsi_func = [mm_op] (const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
129+
const auto ndim = psi_in.shape().ndim();
130+
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
131+
const int nvec = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
132+
const int ld_psi = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
133+
134+
// Note: numpy's py::array_t is row-major, and
135+
// our tensor-array is row-major
136+
py::array_t<std::complex<double>> psi({ld_psi, nvec});
137+
py::buffer_info psi_buf = psi.request();
138+
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
139+
std::copy(psi_in.data<std::complex<double>>(), psi_in.data<std::complex<double>>() + nvec * ld_psi, psi_ptr);
140+
141+
py::array_t<std::complex<double>> hpsi = mm_op(psi);
142+
143+
py::buffer_info hpsi_buf = hpsi.request();
144+
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
145+
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out.data<std::complex<double>>());
146+
};
147+
148+
auto subspace_func = [] (const ct::Tensor& psi_in, ct::Tensor& psi_out) { /*do nothing*/ };
149+
150+
auto spsi_func = [this] (const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
151+
const auto ndim = psi_in.shape().ndim();
152+
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
153+
const int nrow = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
154+
const int nbands = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
155+
syncmem_z2z_h2h_op()(
156+
this->ctx,
157+
this->ctx,
158+
spsi_out.data<std::complex<double>>(),
159+
psi_in.data<std::complex<double>>(),
160+
static_cast<size_t>(nrow * nbands)
161+
);
162+
};
163+
164+
cg = std::make_unique<hsolver::DiagoCG<std::complex<double>, base_device::DEVICE_CPU>>(
165+
basis_type,
166+
calculation,
167+
need_subspace,
168+
subspace_func,
169+
tol,
170+
diag_ndim,
171+
nproc_in_pool
172+
);
173+
174+
cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
175+
}
176+
177+
private:
178+
base_device::DEVICE_CPU* ctx = {};
179+
180+
int dim;
181+
int num_eigs;
182+
183+
ct::Tensor* psi = nullptr;
184+
ct::Tensor* eig = nullptr;
185+
ct::Tensor* prec = nullptr;
186+
187+
std::unique_ptr<hsolver::DiagoCG<std::complex<double>, base_device::DEVICE_CPU>> cg;
188+
};
189+
190+
} // namespace py_hsolver
191+
192+
#endif

python/pyabacus/src/hsolver/py_hsolver.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "./py_diago_dav_subspace.hpp"
1313
#include "./py_diago_david.hpp"
14+
#include "./py_diago_cg.hpp"
1415

1516
namespace py = pybind11;
1617
using namespace pybind11::literals;
@@ -144,6 +145,55 @@ void bind_hsolver(py::module& m)
144145
.def("get_eigenvalue", &py_hsolver::PyDiagoDavid::get_eigenvalue, R"pbdoc(
145146
Get the eigenvalues.
146147
)pbdoc");
148+
149+
py::class_<py_hsolver::PyDiagoCG>(m, "diago_cg")
150+
.def(py::init<int, int>(), R"pbdoc(
151+
Constructor of diago_cg, a class for diagonalizing
152+
a linear operator using the Conjugate Gradient Method.
153+
154+
This class serves as a backend computation class. The interface
155+
for invoking this class is a function defined in _hsolver.py,
156+
which uses this class to perform the calculations.
157+
)pbdoc")
158+
.def("diag", &py_hsolver::PyDiagoCG::diag, R"pbdoc(
159+
Diagonalize the linear operator using the Conjugate Gradient Method.
160+
161+
Parameters
162+
----------
163+
mm_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
164+
The operator to be diagonalized, which is a function that takes a matrix as input
165+
and returns a matrix mv_op(X) = H * X as output.
166+
max_iter : int
167+
The maximum number of iterations.
168+
tol : double
169+
The tolerance for the convergence.
170+
need_subspace : bool
171+
Whether to use the subspace function.
172+
scf_type : bool
173+
Whether to use the SCF type, which is used to determine the
174+
convergence criterion.
175+
)pbdoc",
176+
"mm_op"_a,
177+
"max_iter"_a,
178+
"tol"_a,
179+
"need_subspace"_a,
180+
"scf_type"_a,
181+
"nproc_in_pool"_a)
182+
.def("init_eig", &py_hsolver::PyDiagoCG::init_eig, R"pbdoc(
183+
Initialize the eigenvalues.
184+
)pbdoc")
185+
.def("get_eig", &py_hsolver::PyDiagoCG::get_eig, R"pbdoc(
186+
Get the eigenvalues.
187+
)pbdoc")
188+
.def("set_psi", &py_hsolver::PyDiagoCG::set_psi, R"pbdoc(
189+
Set the eigenvectors.
190+
)pbdoc", "psi_in"_a)
191+
.def("get_psi", &py_hsolver::PyDiagoCG::get_psi, R"pbdoc(
192+
Get the eigenvectors.
193+
)pbdoc")
194+
.def("set_prec", &py_hsolver::PyDiagoCG::set_prec, R"pbdoc(
195+
Set the preconditioner.
196+
)pbdoc", "prec_in"_a);
147197
}
148198

149199
PYBIND11_MODULE(_hsolver_pack, m)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import annotations
22
from ._hsolver import *
33

4-
__all__ = ["diag_comm_info", "dav_subspace", "davidson"]
4+
__all__ = ["diag_comm_info", "dav_subspace", "davidson", "cg"]

0 commit comments

Comments
 (0)