1+ #ifndef PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
2+ #define PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
3+
#include <algorithm>
#include <complex>
#include <cstddef>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>

#include <pybind11/pybind11.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <ATen/core/tensor.h>
#include <ATen/core/tensor_map.h>
#include <ATen/core/tensor_types.h>

#include "module_hsolver/diago_cg.h"
#include "module_base/module_device/memory_op.h"
19+
20+ namespace py = pybind11;
21+
22+ namespace py_hsolver
23+ {
24+
25+ class PyDiagoCG
26+ {
27+ public:
28+ PyDiagoCG (int dim, int num_eigs) : dim{dim}, num_eigs{num_eigs} { }
29+ PyDiagoCG (const PyDiagoCG&) = delete ;
30+ PyDiagoCG& operator =(const PyDiagoCG&) = delete ;
31+ PyDiagoCG (PyDiagoCG&& other)
32+ {
33+ psi = other.psi ;
34+ other.psi = nullptr ;
35+
36+ eig = other.eig ;
37+ other.eig = nullptr ;
38+ }
39+
40+ ~PyDiagoCG ()
41+ {
42+ if (psi != nullptr )
43+ {
44+ delete psi;
45+ psi = nullptr ;
46+ }
47+
48+ if (eig != nullptr )
49+ {
50+ delete eig;
51+ eig = nullptr ;
52+ }
53+ }
54+
55+ void init_eig ()
56+ {
57+ eig = new ct::Tensor (ct::DataType::DT_DOUBLE, {num_eigs});
58+ eig->zero ();
59+ }
60+
61+ py::array_t <double > get_eig ()
62+ {
63+ py::array_t <double > eig_out (eig->NumElements ());
64+ py::buffer_info eig_buf = eig_out.request ();
65+ double * eig_out_ptr = static_cast <double *>(eig_buf.ptr );
66+
67+ if (eig == nullptr ) {
68+ throw std::runtime_error (" eig is not initialized" );
69+ }
70+ double * eig_ptr = eig->data <double >();
71+
72+ std::copy (eig_ptr, eig_ptr + eig->NumElements (), eig_out_ptr);
73+ return eig_out;
74+ }
75+
76+ void set_psi (py::array_t <std::complex <double >> psi_in)
77+ {
78+ py::buffer_info psi_buf = psi_in.request ();
79+ std::complex <double >* psi_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
80+
81+ psi = new ct::TensorMap (
82+ psi_ptr,
83+ ct::DataType::DT_COMPLEX_DOUBLE,
84+ ct::DeviceType::CpuDevice,
85+ ct::TensorShape ({num_eigs, dim})
86+ );
87+ }
88+
89+ py::array_t <std::complex <double >> get_psi ()
90+ {
91+ py::array_t <std::complex <double >> psi_out ({num_eigs, dim});
92+ py::buffer_info psi_buf = psi_out.request ();
93+ std::complex <double >* psi_out_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
94+
95+ if (psi == nullptr ) {
96+ throw std::runtime_error (" psi is not initialized" );
97+ }
98+ std::complex <double >* psi_ptr = psi->data <std::complex <double >>();
99+
100+ std::copy (psi_ptr, psi_ptr + psi->NumElements (), psi_out_ptr);
101+ return psi_out;
102+ }
103+
104+ void set_prec (py::array_t <double > prec_in)
105+ {
106+ py::buffer_info prec_buf = prec_in.request ();
107+ double * prec_ptr = static_cast <double *>(prec_buf.ptr );
108+
109+ prec = new ct::TensorMap (
110+ prec_ptr,
111+ ct::DataType::DT_DOUBLE,
112+ ct::DeviceType::CpuDevice,
113+ ct::TensorShape ({dim})
114+ );
115+ }
116+
117+ void diag (
118+ std::function<py::array_t <std::complex <double >>(py::array_t <std::complex <double >>)> mm_op,
119+ int diag_ndim,
120+ double tol,
121+ bool need_subspace,
122+ bool scf_type,
123+ int nproc_in_pool = 1
124+ ) {
125+ const std::string basis_type = " pw" ;
126+ const std::string calculation = scf_type ? " scf" : " nscf" ;
127+
128+ auto hpsi_func = [mm_op] (const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
129+ const auto ndim = psi_in.shape ().ndim ();
130+ REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
131+ const int nvec = ndim == 1 ? 1 : psi_in.shape ().dim_size (0 );
132+ const int ld_psi = ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 );
133+
134+ // Note: numpy's py::array_t is row-major, and
135+ // our tensor-array is row-major
136+ py::array_t <std::complex <double >> psi ({ld_psi, nvec});
137+ py::buffer_info psi_buf = psi.request ();
138+ std::complex <double >* psi_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
139+ std::copy (psi_in.data <std::complex <double >>(), psi_in.data <std::complex <double >>() + nvec * ld_psi, psi_ptr);
140+
141+ py::array_t <std::complex <double >> hpsi = mm_op (psi);
142+
143+ py::buffer_info hpsi_buf = hpsi.request ();
144+ std::complex <double >* hpsi_ptr = static_cast <std::complex <double >*>(hpsi_buf.ptr );
145+ std::copy (hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out.data <std::complex <double >>());
146+ };
147+
148+ auto subspace_func = [] (const ct::Tensor& psi_in, ct::Tensor& psi_out) { /* do nothing*/ };
149+
150+ auto spsi_func = [this ] (const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
151+ const auto ndim = psi_in.shape ().ndim ();
152+ REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
153+ const int nrow = ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 );
154+ const int nbands = ndim == 1 ? 1 : psi_in.shape ().dim_size (0 );
155+ syncmem_z2z_h2h_op ()(
156+ this ->ctx ,
157+ this ->ctx ,
158+ spsi_out.data <std::complex <double >>(),
159+ psi_in.data <std::complex <double >>(),
160+ static_cast <size_t >(nrow * nbands)
161+ );
162+ };
163+
164+ cg = std::make_unique<hsolver::DiagoCG<std::complex <double >, base_device::DEVICE_CPU>>(
165+ basis_type,
166+ calculation,
167+ need_subspace,
168+ subspace_func,
169+ tol,
170+ diag_ndim,
171+ nproc_in_pool
172+ );
173+
174+ cg->diag (hpsi_func, spsi_func, *psi, *eig, *prec);
175+ }
176+
177+ private:
178+ base_device::DEVICE_CPU* ctx = {};
179+
180+ int dim;
181+ int num_eigs;
182+
183+ ct::Tensor* psi = nullptr ;
184+ ct::Tensor* eig = nullptr ;
185+ ct::Tensor* prec = nullptr ;
186+
187+ std::unique_ptr<hsolver::DiagoCG<std::complex <double >, base_device::DEVICE_CPU>> cg;
188+ };
189+
190+ } // namespace py_hsolver
191+
192+ #endif
0 commit comments