Skip to content

Commit b68d80b

Browse files
Update test_lu_solve in test_usm_type/sycl_queue.py
1 parent a13cd88 commit b68d80b

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,11 +1612,16 @@ def test_lu_factor(self, data, device):
16121612
assert_sycl_queue_equal(param_queue, a.sycl_queue)
16131613

16141614
@pytest.mark.parametrize(
1615-
"b_data",
1616-
[[1.0, 2.0], numpy.empty((2, 0))],
1615+
"a_data, b_data",
1616+
[
1617+
([[1.0, 2.0], [3.0, 5.0]], [1.0, 2.0]),
1618+
([[1.0, 2.0], [3.0, 5.0]], numpy.empty((2, 0))),
1619+
([[[1.0, 2.0], [3.0, 5.0]]], [1.0, 2.0]),
1620+
([[[1.0, 2.0], [3.0, 5.0]]], numpy.empty((2, 0, 2))),
1621+
],
16171622
)
1618-
def test_lu_solve(self, b_data, device):
1619-
a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
1623+
def test_lu_solve(self, a_data, b_data, device):
1624+
a = dpnp.array(a_data, device=device)
16201625
lu, piv = dpnp.linalg.lu_factor(a)
16211626
b = dpnp.array(b_data, device=device)
16221627

dpnp/tests/test_usm_type.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,11 +1489,16 @@ def test_lu_factor(self, data, usm_type):
14891489

14901490
@pytest.mark.parametrize("usm_type_rhs", list_of_usm_types)
14911491
@pytest.mark.parametrize(
1492-
"b_data",
1493-
[[1.0, 2.0], numpy.empty((2, 0))],
1492+
"a_data, b_data",
1493+
[
1494+
([[1.0, 2.0], [3.0, 5.0]], [1.0, 2.0]),
1495+
([[1.0, 2.0], [3.0, 5.0]], numpy.empty((2, 0))),
1496+
([[[1.0, 2.0], [3.0, 5.0]]], [1.0, 2.0]),
1497+
([[[1.0, 2.0], [3.0, 5.0]]], numpy.empty((2, 0, 2))),
1498+
],
14941499
)
1495-
def test_lu_solve(self, b_data, usm_type, usm_type_rhs):
1496-
a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type)
1500+
def test_lu_solve(self, a_data, b_data, usm_type, usm_type_rhs):
1501+
a = dpnp.array(a_data, usm_type=usm_type)
14971502
lu, piv = dpnp.linalg.lu_factor(a)
14981503
b = dpnp.array(b_data, usm_type=usm_type_rhs)
14991504

0 commit comments

Comments
 (0)