Skip to content

Commit c061199

Browse files
Merge impl_lu_factor into impl_lu_factor_batch
2 parents 2ba1268 + 2f21467 commit c061199

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repos:
88
pass_filenames: false
99
args: ["-r", "dpnp", "-lll"]
1010
- repo: https://github.com/pre-commit/pre-commit-hooks
11-
rev: v5.0.0
11+
rev: v6.0.0
1212
hooks:
1313
# Git
1414
- id: check-added-large-files

dpnp/backend/extensions/lapack/getrf.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
7171
T *a = reinterpret_cast<T *>(in_a);
7272

7373
const std::int64_t scratchpad_size =
74-
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda);
74+
mkl_lapack::getrf_scratchpad_size<T>(exec_q, m, n, lda);
7575
T *scratchpad = nullptr;
7676

7777
std::stringstream error_msg;
@@ -88,9 +88,9 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
8888
// It must be a non-negative integer.
8989
n, // The number of columns in the input matrix A (0 ≤ n).
9090
// It must be a non-negative integer.
91-
a, // Pointer to the input matrix A (n x n).
91+
a, // Pointer to the input matrix A (m x n).
9292
lda, // The leading dimension of matrix A.
93-
// It must be at least max(1, n).
93+
// It must be at least max(1, m).
9494
ipiv, // Pointer to the output array of pivot indices.
9595
scratchpad, // Pointer to scratchpad memory to be used by MKL
9696
// routine for storing intermediate results.

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ PYBIND11_MODULE(_lapack_impl, m)
135135

136136
m.def("_getrf", &lapack_ext::getrf,
137137
"Call `getrf` from OneMKL LAPACK library to return "
138-
"the LU factorization of a general n x n matrix",
138+
"the LU factorization of a general m x n matrix",
139139
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
140140
py::arg("dev_info"), py::arg("depends") = py::list());
141141

dpnp/tests/test_linalg.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,13 +1863,17 @@ class TestLuFactor:
18631863
@staticmethod
18641864
def _apply_pivots_rows(A_dp, piv_dp):
18651865
m = A_dp.shape[0]
1866-
rows = dpnp.arange(m)
1867-
for i in range(int(piv_dp.shape[0])):
1868-
r = int(piv_dp[i].item())
1866+
1867+
if m == 0 or piv_dp.size == 0:
1868+
return A_dp
1869+
1870+
rows = list(range(m))
1871+
piv_np = dpnp.asnumpy(piv_dp)
1872+
for i, r in enumerate(piv_np):
18691873
if i != r:
1870-
tmp = rows[i].copy()
1871-
rows[i] = rows[r]
1872-
rows[r] = tmp
1874+
rows[i], rows[r] = rows[r], rows[i]
1875+
1876+
rows = dpnp.asarray(rows)
18731877
return A_dp[rows]
18741878

18751879
@staticmethod
@@ -1955,7 +1959,7 @@ def test_overwrite_copy_special(self):
19551959
a2_orig = a2.copy()
19561960
a2.flags["WRITABLE"] = False
19571961

1958-
for a_dp, a_orig in zip((a1, a1), (a1_orig, a2_orig)):
1962+
for a_dp, a_orig in zip((a1, a2), (a1_orig, a2_orig)):
19591963
lu, piv = dpnp.linalg.lu_factor(
19601964
a_dp, overwrite_a=True, check_finite=False
19611965
)

0 commit comments

Comments
 (0)