Skip to content

Commit 4cf630d

Browse files
Simplify _apply_pivots_rows
1 parent 3276f92 commit 4cf630d

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

dpnp/tests/test_linalg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,13 +1863,17 @@ class TestLuFactor:
18631863
@staticmethod
18641864
def _apply_pivots_rows(A_dp, piv_dp):
18651865
m = A_dp.shape[0]
1866-
rows = dpnp.arange(m)
1867-
for i in range(int(piv_dp.shape[0])):
1868-
r = int(piv_dp[i].item())
1866+
1867+
if m == 0 or piv_dp.size == 0:
1868+
return A_dp
1869+
1870+
rows = list(range(m))
1871+
piv_np = dpnp.asnumpy(piv_dp)
1872+
for i, r in enumerate(piv_np):
18691873
if i != r:
1870-
tmp = rows[i].copy()
1871-
rows[i] = rows[r]
1872-
rows[r] = tmp
1874+
rows[i], rows[r] = rows[r], rows[i]
1875+
1876+
rows = dpnp.asarray(rows)
18731877
return A_dp[rows]
18741878

18751879
@staticmethod

0 commit comments

Comments
 (0)