Skip to content

Commit 8a17f67

Browse files
Raise NotImplementedError for bathed matrices
1 parent 63138e9 commit 8a17f67

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,10 +2314,10 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23142314
if not dpnp.isfinite(a).all():
23152315
raise ValueError("array must not contain infs or NaNs")
23162316

2317-
# if a.ndim > 2:
2318-
# return _batched_lu_factor_scipy(a, res_type, overwrite_a=overwrite_a)
2317+
if a.ndim > 2:
2318+
raise NotImplementedError("Batched matrices are not supported")
23192319

2320-
n = a.shape[-2]
2320+
m, n = a.shape
23212321

23222322
a_sycl_queue = a.sycl_queue
23232323
a_usm_type = a.usm_type
@@ -2346,7 +2346,7 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23462346
copy_ev = None
23472347

23482348
ipiv_h = dpnp.empty(
2349-
n,
2349+
min(m, n),
23502350
dtype=dpnp.int64,
23512351
order="C",
23522352
usm_type=a_usm_type,

0 commit comments

Comments
 (0)