Skip to content

Commit fcd693c

Browse files
Apply remarks
1 parent c061199 commit fcd693c

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

dpnp/tests/test_linalg.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,13 +2022,17 @@ class TestLuFactorBatched:
20222022
@staticmethod
20232023
def _apply_pivots_rows(A_dp, piv_dp):
20242024
m = A_dp.shape[0]
2025-
rows = dpnp.arange(m)
2026-
for i in range(int(piv_dp.shape[0])):
2027-
r = int(piv_dp[i].item())
2025+
2026+
if m == 0 or piv_dp.size == 0:
2027+
return A_dp
2028+
2029+
rows = list(range(m))
2030+
piv_np = dpnp.asnumpy(piv_dp)
2031+
for i, r in enumerate(piv_np):
20282032
if i != r:
2029-
tmp = rows[i].copy()
2030-
rows[i] = rows[r]
2031-
rows[r] = tmp
2033+
rows[i], rows[r] = rows[r], rows[i]
2034+
2035+
rows = dpnp.asarray(rows)
20322036
return A_dp[rows]
20332037

20342038
@staticmethod
@@ -2118,15 +2122,15 @@ def test_strided(self):
21182122
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
21192123

21202124
def test_singular_matrix(self):
2121-
a = dpnp.zeros((3, 2, 2), dtype=dpnp.float64)
2125+
a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type())
21222126
a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
21232127
a[1] = dpnp.eye(2)
21242128
a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]])
21252129
with pytest.warns(RuntimeWarning, match="Singular matrix"):
21262130
dpnp.linalg.lu_factor(a, check_finite=False)
21272131

21282132
def test_check_finite_raises(self):
2129-
a = dpnp.ones((2, 3, 3), dtype=dpnp.float64, order="F")
2133+
a = dpnp.ones((2, 3, 3), dtype=dpnp.default_float_type(), order="F")
21302134
a[1, 0, 0] = dpnp.nan
21312135
assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True)
21322136

0 commit comments

Comments
 (0)