Skip to content

Commit 68d46d0

Browse files
Add _make_nonsingular_nd_np to TestLuFactorBatched
1 parent 5a33d43 commit 68d46d0

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

dpnp/tests/test_linalg.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,23 @@ def _apply_pivots_rows(A_dp, piv_dp):
20522052
rows = dpnp.asarray(rows)
20532053
return A_dp[rows]
20542054

2055+
@staticmethod
2056+
def _make_nonsingular_nd_np(shape, dtype, order):
2057+
A = generate_random_numpy_array(shape, dtype, order)
2058+
m, n = shape[-2], shape[-1]
2059+
k = min(m, n)
2060+
A3 = A.reshape((-1, m, n))
2061+
for B in A3:
2062+
for i in range(k):
2063+
off = numpy.sum(numpy.abs(B[i, :n])) - numpy.abs(B[i, i])
2064+
B[i, i] = A.dtype.type(off + 1.0)
2065+
2066+
A = A3.reshape(shape)
2067+
# A3.reshape returns an array in C order by default
2068+
if order != "C":
2069+
A = numpy.array(A, order=order)
2070+
return A
2071+
20552072
@staticmethod
20562073
def _split_lu(lu, m, n):
20572074
L = dpnp.tril(lu, k=-1)
@@ -2068,7 +2085,7 @@ def _split_lu(lu, m, n):
20682085
@pytest.mark.parametrize("order", ["C", "F"])
20692086
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
20702087
def test_lu_factor_batched(self, shape, order, dtype):
2071-
a_np = generate_random_numpy_array(shape, dtype, order)
2088+
a_np = self._make_nonsingular_nd_np(shape, dtype, order)
20722089
a_dp = dpnp.array(a_np, order=order)
20732090

20742091
lu, piv = dpnp.linalg.lu_factor(
@@ -2092,7 +2109,8 @@ def test_lu_factor_batched(self, shape, order, dtype):
20922109
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
20932110
@pytest.mark.parametrize("order", ["C", "F"])
20942111
def test_overwrite(self, dtype, order):
2095-
a_dp = dpnp.arange(2 * 2 * 3, dtype=dtype).reshape(3, 2, 2, order=order)
2112+
a_np = self._make_nonsingular_nd_np((3, 2, 2), dtype, order)
2113+
a_dp = dpnp.array(a_np, order=order)
20962114
a_dp_orig = a_dp.copy()
20972115
lu, piv = dpnp.linalg.lu_factor(
20982116
a_dp, overwrite_a=True, check_finite=False
@@ -2123,13 +2141,11 @@ def test_empty_inputs(self, shape):
21232141
assert piv.shape == (*shape[:-2], min(m, n))
21242142

21252143
def test_strided(self):
2126-
a = (
2127-
dpnp.arange(5 * 3 * 3, dtype=dpnp.default_float_type()).reshape(
2128-
5, 3, 3, order="F"
2129-
)
2130-
+ 0.1
2144+
a_np = self._make_nonsingular_nd_np(
2145+
(5, 3, 3), dpnp.default_float_type(), "F"
21312146
)
2132-
a_stride = a[::2]
2147+
a_dp = dpnp.array(a_np, order=order)
2148+
a_stride = a_dp[::2]
21332149
lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False)
21342150
for i in range(a_stride.shape[0]):
21352151
L, U = self._split_lu(lu[i], 3, 3)

0 commit comments

Comments
 (0)