11#ifndef PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
22#define PYTHON_PYABACUS_SRC_PY_DIAGO_CG_HPP
33
4+ #include < complex>
5+ #include < functional>
6+
47#include < pybind11/pybind11.h>
58#include < pybind11/complex.h>
69#include < pybind11/functional.h>
710#include < pybind11/numpy.h>
811#include < pybind11/stl.h>
912
1013#include < ATen/core/tensor.h>
14+ #include < ATen/core/tensor_map.h>
1115#include < ATen/core/tensor_types.h>
1216
1317#include " module_hsolver/diago_cg.h"
@@ -20,7 +24,8 @@ namespace py_hsolver
2024
2125class PyDiagoCG
2226{
23- PyDiagoCG () { }
27+ public:
28+ PyDiagoCG (int dim, int num_eigs) : dim{dim}, num_eigs{num_eigs} { }
2429 PyDiagoCG (const PyDiagoCG&) = delete ;
2530 PyDiagoCG& operator =(const PyDiagoCG&) = delete ;
2631 PyDiagoCG (PyDiagoCG&& other)
@@ -47,9 +52,75 @@ class PyDiagoCG
4752 }
4853 }
4954
55+ void init_eig ()
56+ {
57+ eig = new ct::Tensor (ct::DataType::DT_DOUBLE, {num_eigs});
58+ eig->zero ();
59+ }
60+
61+ py::array_t <double > get_eig ()
62+ {
63+ py::array_t <double > eig_out (eig->NumElements ());
64+ py::buffer_info eig_buf = eig_out.request ();
65+ double * eig_out_ptr = static_cast <double *>(eig_buf.ptr );
66+
67+ if (eig == nullptr ) {
68+ throw std::runtime_error (" eig is not initialized" );
69+ }
70+ double * eig_ptr = eig->data <double >();
71+
72+ std::copy (eig_ptr, eig_ptr + eig->NumElements (), eig_out_ptr);
73+ return eig_out;
74+ }
75+
76+ void set_psi (py::array_t <std::complex <double >> psi_in)
77+ {
78+ py::buffer_info psi_buf = psi_in.request ();
79+ std::complex <double >* psi_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
80+
81+ psi = new ct::TensorMap (
82+ psi_ptr,
83+ ct::DataType::DT_COMPLEX_DOUBLE,
84+ ct::DeviceType::CpuDevice,
85+ ct::TensorShape ({num_eigs, dim})
86+ );
87+ }
88+
89+ py::array_t <std::complex <double >> get_psi ()
90+ {
91+ py::array_t <std::complex <double >> psi_out ({num_eigs, dim});
92+ py::buffer_info psi_buf = psi_out.request ();
93+ std::complex <double >* psi_out_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
94+
95+ if (psi == nullptr ) {
96+ throw std::runtime_error (" psi is not initialized" );
97+ }
98+ std::complex <double >* psi_ptr = psi->data <std::complex <double >>();
99+
100+ std::copy (psi_ptr, psi_ptr + psi->NumElements (), psi_out_ptr);
101+ return psi_out;
102+ }
103+
104+ void set_prec (py::array_t <double > prec_in)
105+ {
106+ py::buffer_info prec_buf = prec_in.request ();
107+ double * prec_ptr = static_cast <double *>(prec_buf.ptr );
108+
109+ prec = new ct::TensorMap (
110+ prec_ptr,
111+ ct::DataType::DT_DOUBLE,
112+ ct::DeviceType::CpuDevice,
113+ ct::TensorShape ({dim})
114+ );
115+ }
116+
50117 void diag (
51118 std::function<py::array_t <std::complex <double >>(py::array_t <std::complex <double >>)> mm_op,
52- bool scf_type
119+ int diag_ndim,
120+ double tol,
121+ bool need_subspace,
122+ bool scf_type,
123+ int nproc_in_pool = 1
53124 ) {
54125 const std::string basis_type = " pw" ;
55126 const std::string calculation = scf_type ? " scf" : " nscf" ;
@@ -58,18 +129,18 @@ class PyDiagoCG
58129 const auto ndim = psi_in.shape ().ndim ();
59130 REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
60131 const int nvec = ndim == 1 ? 1 : psi_in.shape ().dim_size (0 );
61- const int ld_psi = ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 )
132+ const int ld_psi = ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 );
62133
63- // py::array_t<std::complex<double>> psi({ld_psi, nvec});
64- // py::buffer_info psi_buf = psi.request();
65- // std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
66- // std::copy(psi_in, psi_in + nvec * ld_psi, psi_ptr);
134+ py::array_t <std::complex <double >> psi ({ld_psi, nvec});
135+ py::buffer_info psi_buf = psi.request ();
136+ std::complex <double >* psi_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
137+ std::copy (psi_in. data <std:: complex < double >>() , psi_in. data <std:: complex < double >>() + nvec * ld_psi, psi_ptr);
67138
68- // py::array_t<std::complex<double>> hpsi = mm_op(psi);
139+ py::array_t <std::complex <double >> hpsi = mm_op (psi);
69140
70- // py::buffer_info hpsi_buf = hpsi.request();
71- // std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
72- // std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
141+ py::buffer_info hpsi_buf = hpsi.request ();
142+ std::complex <double >* hpsi_ptr = static_cast <std::complex <double >*>(hpsi_buf.ptr );
143+ std::copy (hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out. data <std:: complex < double >>() );
73144 };
74145
75146 auto subspace_func = [] (const ct::Tensor& psi_in, ct::Tensor& psi_out) { /* do nothing*/ };
@@ -87,14 +158,31 @@ class PyDiagoCG
87158 static_cast <size_t >(nrow * nbands)
88159 );
89160 };
161+
162+ cg = std::make_unique<hsolver::DiagoCG<std::complex <double >, base_device::DEVICE_CPU>>(
163+ basis_type,
164+ calculation,
165+ need_subspace,
166+ subspace_func,
167+ tol,
168+ diag_ndim,
169+ nproc_in_pool
170+ );
171+
172+ return cg->diag (hpsi_func, spsi_func, *psi, *eig, *prec);
90173 }
91174
92175private:
93176 base_device::DEVICE_CPU* ctx = {};
94177
178+ int dim;
179+ int num_eigs;
180+
95181 ct::Tensor* psi = nullptr ;
96182 ct::Tensor* eig = nullptr ;
183+ ct::Tensor* prec = nullptr ;
97184
185+ std::unique_ptr<hsolver::DiagoCG<std::complex <double >, base_device::DEVICE_CPU>> cg;
98186};
99187
100188} // namespace py_hsolver
0 commit comments