Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,13 +449,15 @@ def _batched_qr(a, mode="reduced"):
a_t,
shape=(batch_size, m, m),
dtype=res_type,
order="C",
)
else:
mc = k
q = dpnp.empty_like(
a_t,
shape=(batch_size, n, m),
dtype=res_type,
order="C",
)

# use DPCTL tensor function to fill the matrix array `q[..., :n, :]`
Expand Down Expand Up @@ -2532,13 +2534,15 @@ def dpnp_qr(a, mode="reduced"):
a_t,
shape=(m, m),
dtype=res_type,
order="C",
)
else:
mc = k
q = dpnp.empty_like(
a_t,
shape=(n, m),
dtype=res_type,
order="C",
)

# use DPCTL tensor function to fill the matrix array `q[:n]`
Expand Down
14 changes: 13 additions & 1 deletion dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2372,12 +2372,24 @@ class TestQr:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"shape",
[(2, 2), (3, 4), (5, 3), (16, 16), (2, 2, 2), (2, 4, 2), (2, 2, 4)],
[
(2, 1),
(2, 2),
(3, 4),
(5, 3),
(16, 16),
(3, 3, 1),
(2, 2, 2),
(2, 4, 2),
(2, 2, 4),
],
ids=[
"(2, 1)",
"(2, 2)",
"(3, 4)",
"(5, 3)",
"(16, 16)",
"(3, 3, 1)",
"(2, 2, 2)",
"(2, 4, 2)",
"(2, 2, 4)",
Expand Down
Loading