Skip to content

Commit 97b6202

Browse files
Handle empty inputs correctly for dpnp.linalg.lu_factor()
1 parent 25850f4 commit 97b6202

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def _batched_lu_factor(a, res_type):
398398
return (out_a, out_ipiv, out_dev_info)
399399

400400

401-
def _batched_lu_factor_scipy(a, res_type):
401+
def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
402402
"""SciPy-compatible LU factorization for batched inputs."""
403403

404404
# TODO: Find out at which array sizes the best performance is obtained
@@ -414,6 +414,19 @@ def _batched_lu_factor_scipy(a, res_type):
414414
m, n = a.shape[-2:]
415415
k = min(m, n)
416416
orig_shape = a.shape
417+
batch_shape = orig_shape[:-2]
418+
419+
# accommodate empty arrays
420+
if a.size == 0:
421+
lu = dpnp.empty_like(a)
422+
piv = dpnp.empty(
423+
(*batch_shape, k),
424+
dtype=dpnp.int64,
425+
usm_type=a_usm_type,
426+
sycl_queue=a_sycl_queue,
427+
)
428+
return lu, piv
429+
417430
# get 3d input arrays by reshape
418431
a = dpnp.reshape(a, (-1, m, n))
419432
batch_size = a.shape[0]
@@ -468,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type):
468481
# Batch was moved to the last axis before the call.
469482
# Move it back to the front and reshape to the original shape.
470483
a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
471-
ipiv_h = ipiv_h.reshape((*orig_shape[:-2], k))
484+
ipiv_h = ipiv_h.reshape((*batch_shape, k))
472485

473486
if any(dev_info_h):
474487
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
@@ -530,7 +543,7 @@ def _batched_lu_factor_scipy(a, res_type):
530543

531544
# Reshape the results back to their original shape
532545
out_a = dpnp.array(a_vecs).reshape(orig_shape)
533-
out_ipiv = dpnp.array(ipiv_vecs).reshape((*orig_shape[:-2], k))
546+
out_ipiv = dpnp.array(ipiv_vecs).reshape((*batch_shape, k))
534547

535548
diag_nums = ", ".join(
536549
str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
@@ -2463,14 +2476,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24632476
a_sycl_queue = a.sycl_queue
24642477
a_usm_type = a.usm_type
24652478

2466-
# accommodate empty arrays
2467-
if a.size == 0:
2468-
lu = dpnp.empty_like(a)
2469-
piv = dpnp.arange(
2470-
0, dtype=dpnp.int64, usm_type=a_usm_type, sycl_queue=a_sycl_queue
2471-
)
2472-
return lu, piv
2473-
24742479
if check_finite:
24752480
if not dpnp.isfinite(a).all():
24762481
raise ValueError("array must not contain infs or NaNs")
@@ -2480,6 +2485,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24802485
# so `overwrite_a` is ignored here
24812486
return _batched_lu_factor_scipy(a, res_type)
24822487

2488+
# accommodate empty arrays
2489+
if a.size == 0:
2490+
lu = dpnp.empty_like(a)
2491+
piv = dpnp.arange(
2492+
0, dtype=dpnp.int64, usm_type=a_usm_type, sycl_queue=a_sycl_queue
2493+
)
2494+
return lu, piv
2495+
24832496
_manager = dpu.SequentialOrderManager[a_sycl_queue]
24842497
a_usm_arr = dpnp.get_usm_ndarray(a)
24852498

0 commit comments

Comments
 (0)