Skip to content

Commit 889871e

Browse files
committed
Modified the hpsi_func in pyabacus to maintain definition consistency
1 parent 9d4a7c9 commit 889871e

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

python/pyabacus/src/py_diago_dav_subspace.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,21 @@ class PyDiagoDavSubspace
113113
auto hpsi_func = [mm_op] (
114114
std::complex<double> *psi_in,
115115
std::complex<double> *hpsi_out,
116-
const int nband_in,
117-
const int nbasis_in,
118-
const int band_index1,
119-
const int band_index2
116+
const int ldPsi,
117+
const int nvec
120118
) {
121119
// Note: numpy's py::array_t is row-major, but
122120
// our raw pointer-array is column-major
123-
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});
121+
py::array_t<std::complex<double>, py::array::f_style> psi({ldPsi, band_index2 - band_index1 + 1});
124122
py::buffer_info psi_buf = psi.request();
125123
std::complex<double>* psi_ptr = static_cast<std::complex<double>*>(psi_buf.ptr);
126-
std::copy(psi_in + band_index1 * nbasis_in, psi_in + (band_index2 + 1) * nbasis_in, psi_ptr);
124+
std::copy(psi_in, psi_in + nvec * ldPsi, psi_ptr);
127125

128126
py::array_t<std::complex<double>, py::array::f_style> hpsi = mm_op(psi);
129127

130128
py::buffer_info hpsi_buf = hpsi.request();
131129
std::complex<double>* hpsi_ptr = static_cast<std::complex<double>*>(hpsi_buf.ptr);
132-
std::copy(hpsi_ptr, hpsi_ptr + (band_index2 - band_index1 + 1) * nbasis_in, hpsi_out);
130+
std::copy(hpsi_ptr, hpsi_ptr + nvec * ldPsi, hpsi_out);
133131
};
134132

135133
obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(

0 commit comments

Comments
 (0)