@@ -111,23 +111,21 @@ class PyDiagoDavid
111111 auto hpsi_func = [mm_op] (
112112 std::complex <double > *psi_in,
113113 std::complex <double > *hpsi_out,
114- const int nband_in,
115- const int nbasis_in,
116- const int band_index1,
117- const int band_index2
114+ const int ldPsi,
115+ const int nvec
118116 ) {
119117 // Note: numpy's py::array_t is row-major, but
120118 // our raw pointer-array is column-major
121- py::array_t <std::complex <double >, py::array::f_style> psi ({nbasis_in, band_index2 - band_index1 + 1 });
119+ py::array_t <std::complex <double >, py::array::f_style> psi ({ldPsi, nvec });
122120 py::buffer_info psi_buf = psi.request ();
123121 std::complex <double >* psi_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
124- std::copy (psi_in + band_index1 * nbasis_in , psi_in + (band_index2 + 1 ) * nbasis_in , psi_ptr);
122+ std::copy (psi_in, psi_in + nvec * ldPsi , psi_ptr);
125123
126124 py::array_t <std::complex <double >, py::array::f_style> hpsi = mm_op (psi);
127125
128126 py::buffer_info hpsi_buf = hpsi.request ();
129127 std::complex <double >* hpsi_ptr = static_cast <std::complex <double >*>(hpsi_buf.ptr );
130- std::copy (hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1 ) * nbasis_in , hpsi_out);
128+ std::copy (hpsi_ptr, hpsi_ptr + nvec * ldPsi , hpsi_out);
131129 };
132130
133131 auto spsi_func = [this ] (
0 commit comments