Skip to content

Commit e781beb

Browse files
Add sycl_queue and usm_type tests
1 parent 6fe2fe4 commit e781beb

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,11 +2303,15 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23032303
"""
23042304

23052305
res_type = _common_type(a)
2306+
a_sycl_queue = a.sycl_queue
2307+
a_usm_type = a.usm_type
23062308

23072309
# accommodate empty arrays
23082310
if a.size == 0:
23092311
lu = dpnp.empty_like(a)
2310-
piv = dpnp.arange(0, dtype=dpnp.int64)
2312+
piv = dpnp.arange(
2313+
0, dtype=dpnp.int64, usm_type=a_usm_type, sycl_queue=a_sycl_queue
2314+
)
23112315
return lu, piv
23122316

23132317
if check_finite:
@@ -2317,12 +2321,7 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23172321
if a.ndim > 2:
23182322
raise NotImplementedError("Batched matrices are not supported")
23192323

2320-
m, n = a.shape
2321-
2322-
a_sycl_queue = a.sycl_queue
2323-
a_usm_type = a.usm_type
23242324
_manager = dpu.SequentialOrderManager[a_sycl_queue]
2325-
23262325
a_usm_arr = dpnp.get_usm_ndarray(a)
23272326

23282327
# SciPy-compatible behavior
@@ -2345,6 +2344,8 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23452344
a_h = a
23462345
copy_ev = None
23472346

2347+
m, n = a.shape
2348+
23482349
ipiv_h = dpnp.empty(
23492350
min(m, n),
23502351
dtype=dpnp.int64,

dpnp/tests/test_sycl_queue.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,6 +1570,18 @@ def test_lstsq(self, m, n, nrhs, device):
15701570
assert_sycl_queue_equal(param_queue, a.sycl_queue)
15711571
assert_sycl_queue_equal(param_queue, b.sycl_queue)
15721572

1573+
@pytest.mark.parametrize(
1574+
"data",
1575+
[[[1.0, 2.0], [3.0, 5.0]], [[]]],
1576+
)
1577+
def test_lu_factor(self, data, device):
1578+
a = dpnp.array(data, device=device)
1579+
result = dpnp.linalg.lu_factor(a)
1580+
1581+
for param in result:
1582+
param_queue = param.sycl_queue
1583+
assert_sycl_queue_equal(param_queue, a.sycl_queue)
1584+
15731585
@pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
15741586
def test_matrix_power(self, n, device):
15751587
x = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)

dpnp/tests/test_usm_type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,18 @@ def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other):
14491449
[usm_type, usm_type_other]
14501450
)
14511451

1452+
@pytest.mark.parametrize(
1453+
"data",
1454+
[[[1.0, 2.0], [3.0, 5.0]], [[]]],
1455+
)
1456+
def test_lu_factor(self, data, usm_type):
1457+
a = dpnp.array(data, usm_type=usm_type)
1458+
result = dpnp.linalg.lu_factor(a)
1459+
1460+
assert a.usm_type == usm_type
1461+
for param in result:
1462+
assert param.usm_type == a.usm_type
1463+
14521464
@pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
14531465
def test_matrix_power(self, n, usm_type):
14541466
a = dpnp.array([[1, 2], [3, 5]], usm_type=usm_type)

0 commit comments

Comments
 (0)