Skip to content

Commit eb8c58a

Browse files
Update test_empty_shapes for lu_solve()
1 parent f6d77fe commit eb8c58a

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

dpnp/tests/test_linalg.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2320,19 +2320,15 @@ def test_broadcast_rhs(self, b_shape):
23202320

23212321
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5)
23222322

2323-
@pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)])
2324-
@pytest.mark.parametrize("rhs_cols", [None, 0, 3])
2325-
def test_empty_shapes(self, shape, rhs_cols):
2326-
a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
2327-
if min(shape) > 0:
2328-
for i in range(min(shape)):
2323+
@pytest.mark.parametrize("a_shape", [(0, 0), (5, 5)])
2324+
@pytest.mark.parametrize("b_shape", [(0,), (0, 0), (0, 5)])
2325+
def test_empty_shapes(self, a_shape, b_shape):
2326+
a_dp = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F")
2327+
n = a_shape[0]
2328+
2329+
if n > 0:
2330+
for i in range(n):
23292331
a_dp[i, i] = a_dp.dtype.type(1.0)
2330-
2331-
n = shape[0]
2332-
if rhs_cols is None:
2333-
b_shape = (n,)
2334-
else:
2335-
b_shape = (n, rhs_cols)
23362332
b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
23372333

23382334
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
@@ -2537,7 +2533,7 @@ def test_diff_type(self, dtype_a, dtype_b, b_shape):
25372533
((0, 0, 0), (0, 0)),
25382534
],
25392535
)
2540-
def test_empty_inputs(self, a_shape, b_shape):
2536+
def test_empty_shapes(self, a_shape, b_shape):
25412537
a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F")
25422538
b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
25432539

0 commit comments

Comments
 (0)