Skip to content

Commit c5da614

Browse files
committed
add cd to hsolver
1 parent 18f52c4 commit c5da614

File tree

5 files changed

+189
-18
lines changed

5 files changed

+189
-18
lines changed

python/pyabacus/examples/diago_matrix.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,15 @@ def load_mat(mat_file):
1212
def calc_eig_pyabacus(mat_file, method):
1313
algo = {
1414
'dav_subspace': hsolver.dav_subspace,
15-
'davidson': hsolver.davidson
15+
'davidson': hsolver.davidson,
16+
'cg': hsolver.cg
1617
}
1718

19+
if method is not 'cg':
20+
ndim = 8
21+
else:
22+
ndim = 30
23+
1824
h_mat, nbasis, nband = load_mat(mat_file)
1925

2026
v0 = np.random.rand(nbasis, nband)
@@ -31,9 +37,8 @@ def mm_op(x):
3137
nbasis,
3238
nband,
3339
precond,
34-
dav_ndim=8,
35-
tol=1e-8,
36-
max_iter=1000
40+
ndim,
41+
1e-8 # tol
3742
)
3843

3944
print(f'eigenvalues calculated by pyabacus-{method} is: \n', e)
@@ -50,7 +55,7 @@ def calc_eig_scipy(mat_file):
5055

5156
if __name__ == '__main__':
5257
mat_file = './Si2.mat'
53-
method = ['dav_subspace', 'davidson']
58+
method = ['dav_subspace', 'davidson', 'cg']
5459

5560
for m in method:
5661
print(f'\n====== Calculating eigenvalues using {m} method... ======')

python/pyabacus/src/hsolver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
list(APPEND _diago
33
${HSOLVER_PATH}/diago_dav_subspace.cpp
44
${HSOLVER_PATH}/diago_david.cpp
5+
${HSOLVER_PATH}/diago_cg.cpp
56
${HSOLVER_PATH}/diag_const_nums.cpp
67
${HSOLVER_PATH}/diago_iter_assist.cpp
78

python/pyabacus/src/hsolver/py_diago_cg.hpp

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#ifndef PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
22
#define PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
33

4+
#include <complex>
5+
#include <functional>
6+
47
#include <pybind11/pybind11.h>
58
#include <pybind11/complex.h>
69
#include <pybind11/functional.h>
710
#include <pybind11/numpy.h>
811
#include <pybind11/stl.h>
912

1013
#include <ATen/core/tensor.h>
14+
#include <ATen/core/tensor_map.h>
1115
#include <ATen/core/tensor_types.h>
1216

1317
#include "module_hsolver/diago_cg.h"
@@ -20,7 +24,8 @@ namespace py_hsolver
2024

2125
class PyDiagoCG
2226
{
23-
PyDiagoCG() { }
27+
public:
28+
PyDiagoCG(int dim, int num_eigs) : dim{dim}, num_eigs{num_eigs} { }
2429
PyDiagoCG(const PyDiagoCG&) = delete;
2530
PyDiagoCG& operator=(const PyDiagoCG&) = delete;
2631
PyDiagoCG(PyDiagoCG&& other)
@@ -47,9 +52,75 @@ class PyDiagoCG
4752
}
4853
}
4954

55+
void init_eig()
56+
{
57+
eig = new ct::Tensor(ct::DataType::DT_DOUBLE, {num_eigs});
58+
eig->zero();
59+
}
60+
61+
py::array_t<double> get_eig()
62+
{
63+
py::array_t<double> eig_out(eig->NumElements());
64+
py::buffer_info eig_buf = eig_out.request();
65+
double* eig_out_ptr = static_cast<double*>(eig_buf.ptr);
66+
67+
if (eig == nullptr) {
68+
throw std::runtime_error("eig is not initialized");
69+
}
70+
double* eig_ptr = eig->data<double>();
71+
72+
std::copy(eig_ptr, eig_ptr + eig->NumElements(), eig_out_ptr);
73+
return eig_out;
74+
}
75+
76+
void set_psi(py::array_t<std::complex<double>> psi_in)
77+
{
78+
py::buffer_info psi_buf = psi_in.request();
79+
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
80+
81+
psi = new ct::TensorMap(
82+
psi_ptr,
83+
ct::DataType::DT_COMPLEX_DOUBLE,
84+
ct::DeviceType::CpuDevice,
85+
ct::TensorShape({num_eigs, dim})
86+
);
87+
}
88+
89+
py::array_t<std::complex<double>> get_psi()
90+
{
91+
py::array_t<std::complex<double>> psi_out({num_eigs, dim});
92+
py::buffer_info psi_buf = psi_out.request();
93+
std::complex<double>* psi_out_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
94+
95+
if (psi == nullptr) {
96+
throw std::runtime_error("psi is not initialized");
97+
}
98+
std::complex<double>* psi_ptr = psi->data<std::complex<double>>();
99+
100+
std::copy(psi_ptr, psi_ptr + psi->NumElements(), psi_out_ptr);
101+
return psi_out;
102+
}
103+
104+
void set_prec(py::array_t<double> prec_in)
105+
{
106+
py::buffer_info prec_buf = prec_in.request();
107+
double* prec_ptr = static_cast<double*>(prec_buf.ptr);
108+
109+
prec = new ct::TensorMap(
110+
prec_ptr,
111+
ct::DataType::DT_DOUBLE,
112+
ct::DeviceType::CpuDevice,
113+
ct::TensorShape({dim})
114+
);
115+
}
116+
50117
void diag(
51118
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
52-
bool scf_type
119+
int diag_ndim,
120+
double tol,
121+
bool need_subspace,
122+
bool scf_type,
123+
int nproc_in_pool = 1
53124
) {
54125
const std::string basis_type = "pw";
55126
const std::string calculation = scf_type ? "scf" : "nscf";
@@ -58,18 +129,18 @@ class PyDiagoCG
58129
const auto ndim = psi_in.shape().ndim();
59130
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
60131
const int nvec = ndim == 1 ? 1 : psi_in.shape().dim_size(0);
61-
const int ld_psi = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1)
132+
const int ld_psi = ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1);
62133

63-
// py::array_t<std::complex<double>> psi({ld_psi, nvec});
64-
// py::buffer_info psi_buf = psi.request();
65-
// std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
66-
// std::copy(psi_in, psi_in + nvec * ld_psi, psi_ptr);
134+
py::array_t<std::complex<double>> psi({ld_psi, nvec});
135+
py::buffer_info psi_buf = psi.request();
136+
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
137+
std::copy(psi_in.data<std::complex<double>>(), psi_in.data<std::complex<double>>() + nvec * ld_psi, psi_ptr);
67138

68-
// py::array_t<std::complex<double>> hpsi = mm_op(psi);
139+
py::array_t<std::complex<double>> hpsi = mm_op(psi);
69140

70-
// py::buffer_info hpsi_buf = hpsi.request();
71-
// std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
72-
// std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
141+
py::buffer_info hpsi_buf = hpsi.request();
142+
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
143+
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out.data<std::complex<double>>());
73144
};
74145

75146
auto subspace_func = [] (const ct::Tensor& psi_in, ct::Tensor& psi_out) { /*do nothing*/ };
@@ -87,14 +158,31 @@ class PyDiagoCG
87158
static_cast<size_t>(nrow * nbands)
88159
);
89160
};
161+
162+
cg = std::make_unique<hsolver::DiagoCG<std::complex<double>, base_device::DEVICE_CPU>>(
163+
basis_type,
164+
calculation,
165+
need_subspace,
166+
subspace_func,
167+
tol,
168+
diag_ndim,
169+
nproc_in_pool
170+
);
171+
172+
return cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
90173
}
91174

92175
private:
93176
base_device::DEVICE_CPU* ctx = {};
94177

178+
int dim;
179+
int num_eigs;
180+
95181
ct::Tensor* psi = nullptr;
96182
ct::Tensor* eig = nullptr;
183+
ct::Tensor* prec = nullptr;
97184

185+
std::unique_ptr<hsolver::DiagoCG<std::complex<double>, base_device::DEVICE_CPU>> cg;
98186
};
99187

100188
} // namespace py_hsolver

python/pyabacus/src/hsolver/py_hsolver.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "./py_diago_dav_subspace.hpp"
1313
#include "./py_diago_david.hpp"
14+
#include "./py_diago_cg.hpp"
1415

1516
namespace py = pybind11;
1617
using namespace pybind11::literals;
@@ -144,6 +145,44 @@ void bind_hsolver(py::module& m)
144145
.def("get_eigenvalue", &py_hsolver::PyDiagoDavid::get_eigenvalue, R"pbdoc(
145146
Get the eigenvalues.
146147
)pbdoc");
148+
149+
py::class_<py_hsolver::PyDiagoCG>(m, "diago_cg")
150+
.def(py::init<int, int>(), R"pbdoc(
151+
Constructor of diago_cg, a class for diagonalizing
152+
a linear operator using the Conjugate Gradient Method.
153+
154+
This class serves as a backend computation class. The interface
155+
for invoking this class is a function defined in _hsolver.py,
156+
which uses this class to perform the calculations.
157+
)pbdoc")
158+
.def("diag", &py_hsolver::PyDiagoCG::diag, R"pbdoc(
159+
Diagonalize the linear operator using the Conjugate Gradient Method.
160+
161+
Parameters
162+
----------
163+
TO BE FILLED
164+
)pbdoc",
165+
"mm_op"_a,
166+
"diag_ndim"_a,
167+
"tol"_a,
168+
"need_subspace"_a,
169+
"scf_type"_a,
170+
"nproc_in_pool"_a)
171+
.def("init_eig", &py_hsolver::PyDiagoCG::init_eig, R"pbdoc(
172+
Initialize the eigenvalues.
173+
)pbdoc")
174+
.def("get_eig", &py_hsolver::PyDiagoCG::get_eig, R"pbdoc(
175+
Get the eigenvalues.
176+
)pbdoc")
177+
.def("set_psi", &py_hsolver::PyDiagoCG::set_psi, R"pbdoc(
178+
Set the eigenvectors.
179+
)pbdoc", "psi_in"_a)
180+
.def("get_psi", &py_hsolver::PyDiagoCG::get_psi, R"pbdoc(
181+
Get the eigenvectors.
182+
)pbdoc")
183+
.def("set_prec", &py_hsolver::PyDiagoCG::set_prec, R"pbdoc(
184+
Set the preconditioner.
185+
)pbdoc", "prec_in"_a);
147186
}
148187

149188
PYBIND11_MODULE(_hsolver_pack, m)

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Tuple, List, Union, Callable
1010

1111
from ._hsolver_pack import diag_comm_info as _diag_comm_info
12-
from ._hsolver_pack import diago_dav_subspace, diago_david
12+
from ._hsolver_pack import diago_dav_subspace, diago_david, diago_cg
1313

1414
class diag_comm_info(_diag_comm_info):
1515
def __init__(self, rank: int, nproc: int):
@@ -179,4 +179,42 @@ def davidson(
179179
v = _diago_obj_david.get_psi()
180180

181181
return e, v
182-
182+
183+
def cg(
184+
mvv_op: Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
185+
init_v: NDArray[np.complex128],
186+
dim: int,
187+
num_eigs: int,
188+
pre_condition: NDArray[np.float64],
189+
diag_ndim: int = 2,
190+
tol: float = 1e-2,
191+
need_subspace: bool = False,
192+
scf_type: bool = False
193+
) -> Tuple[NDArray[np.float64], NDArray[np.complex128]]:
194+
if not callable(mvv_op):
195+
raise TypeError("mvv_op must be a callable object.")
196+
197+
if init_v.ndim != 1 or init_v.dtype != np.complex128:
198+
# the shape of init_v is (num_eigs, dim) = (dim, num_eigs).T
199+
if init_v.ndim == 2:
200+
init_v = init_v.T
201+
init_v = init_v.flatten().astype(np.complex128, order='C')
202+
203+
_diago_obj_cg = diago_cg(dim, num_eigs)
204+
_diago_obj_cg.set_psi(init_v)
205+
_diago_obj_cg.init_eig()
206+
207+
_diago_obj_cg.set_prec(pre_condition)
208+
209+
_ = _diago_obj_cg.diag(
210+
mvv_op,
211+
diag_ndim,
212+
tol,
213+
need_subspace,
214+
scf_type,
215+
)
216+
217+
e = _diago_obj_cg.get_eigenvalue()
218+
v = _diago_obj_cg.get_psi()
219+
220+
return e, v

0 commit comments

Comments
 (0)