Skip to content

Commit 2716cef

Browse files
committed
Modify the hpsi_func in pyabacus to maintain definition consistency
1 parent 4fd5094 commit 2716cef

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

python/pyabacus/src/py_diago_david.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,23 +111,21 @@ class PyDiagoDavid
111111
auto hpsi_func = [mm_op] (
112112
std::complex<double> *psi_in,
113113
std::complex<double> *hpsi_out,
114-
const int nband_in,
115-
const int nbasis_in,
116-
const int band_index1,
117-
const int band_index2
114+
const int ldPsi,
115+
const int nvec
118116
) {
119117
// Note: numpy's py::array_t is row-major, but
120118
// our raw pointer-array is column-major
121-
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});
119+
py::array_t<std::complex<double>, py::array::f_style> psi({ldPsi, nvec});
122120
py::buffer_info psi_buf = psi.request();
123121
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
124-
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr);
122+
std::copy(psi_in, psi_in + nvec * ldPsi, psi_ptr);
125123

126124
py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi);
127125

128126
py::buffer_info hpsi_buf = hpsi.request();
129127
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
130-
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out);
128+
std::copy(hpsi_ptr, hpsi_ptr + nvec * ldPsi, hpsi_out);
131129
};
132130

133131
auto spsi_func = [this] (

0 commit comments

Comments
 (0)