Skip to content

Commit 3276f92

Browse files
Apply remarks
1 parent 25f7ab3 commit 3276f92

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

dpnp/backend/extensions/lapack/getrf.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
7171
T *a = reinterpret_cast<T *>(in_a);
7272

7373
const std::int64_t scratchpad_size =
74-
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda);
74+
mkl_lapack::getrf_scratchpad_size<T>(exec_q, m, n, lda);
7575
T *scratchpad = nullptr;
7676

7777
std::stringstream error_msg;
@@ -88,9 +88,9 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
8888
// It must be a non-negative integer.
8989
n, // The number of columns in the input matrix A (0 ≤ n).
9090
// It must be a non-negative integer.
91-
a, // Pointer to the input matrix A (n x n).
91+
a, // Pointer to the input matrix A (m x n).
9292
lda, // The leading dimension of matrix A.
93-
// It must be at least max(1, n).
93+
// It must be at least max(1, m).
9494
ipiv, // Pointer to the output array of pivot indices.
9595
scratchpad, // Pointer to scratchpad memory to be used by MKL
9696
// routine for storing intermediate results.

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ PYBIND11_MODULE(_lapack_impl, m)
135135

136136
m.def("_getrf", &lapack_ext::getrf,
137137
"Call `getrf` from OneMKL LAPACK library to return "
138-
"the LU factorization of a general n x n matrix",
138+
"the LU factorization of a general m x n matrix",
139139
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
140140
py::arg("dev_info"), py::arg("depends") = py::list());
141141

dpnp/tests/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1955,7 +1955,7 @@ def test_overwrite_copy_special(self):
19551955
a2_orig = a2.copy()
19561956
a2.flags["WRITABLE"] = False
19571957

1958-
for a_dp, a_orig in zip((a1, a1), (a1_orig, a2_orig)):
1958+
for a_dp, a_orig in zip((a1, a2), (a1_orig, a2_orig)):
19591959
lu, piv = dpnp.linalg.lu_factor(
19601960
a_dp, overwrite_a=True, check_finite=False
19611961
)

0 commit comments

Comments
 (0)