@@ -113,23 +113,21 @@ class PyDiagoDavSubspace
113113 auto hpsi_func = [mm_op] (
114114 std::complex <double > *psi_in,
115115 std::complex <double > *hpsi_out,
116- const int nband_in,
117- const int nbasis_in,
118- const int band_index1,
119- const int band_index2
116+ const int ldPsi,
117+ const int nvec
120118 ) {
121119 // Note: numpy's py::array_t is row-major, but
122120 // our raw pointer-array is column-major
123- py::array_t <std::complex <double >, py::array::f_style> psi ({nbasis_in , band_index2 - band_index1 + 1 });
121+ py::array_t <std::complex <double >, py::array::f_style> psi ({ldPsi , band_index2 - band_index1 + 1 });
124122 py::buffer_info psi_buf = psi.request ();
125123 std::complex <double >* psi_ptr = static_cast <std::complex <double >*>(psi_buf.ptr );
126- std::copy (psi_in + band_index1 * nbasis_in , psi_in + (band_index2 + 1 ) * nbasis_in , psi_ptr);
124+ std::copy (psi_in, psi_in + nvec * ldPsi , psi_ptr);
127125
128126 py::array_t <std::complex <double >, py::array::f_style> hpsi = mm_op (psi);
129127
130128 py::buffer_info hpsi_buf = hpsi.request ();
131129 std::complex <double >* hpsi_ptr = static_cast <std::complex <double >*>(hpsi_buf.ptr );
132- std::copy (hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1 ) * nbasis_in , hpsi_out);
130+ std::copy (hpsi_ptr, hpsi_ptr + nvec * ldPsi , hpsi_out);
133131 };
134132
135133 obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex <double >, base_device::DEVICE_CPU>>(
0 commit comments