@@ -398,7 +398,7 @@ def _batched_lu_factor(a, res_type):
398
398
return (out_a , out_ipiv , out_dev_info )
399
399
400
400
401
- def _batched_lu_factor_scipy (a , res_type ):
401
+ def _batched_lu_factor_scipy (a , res_type ): # pylint: disable=too-many-locals
402
402
"""SciPy-compatible LU factorization for batched inputs."""
403
403
404
404
# TODO: Find out at which array sizes the best performance is obtained
@@ -414,6 +414,19 @@ def _batched_lu_factor_scipy(a, res_type):
414
414
m , n = a .shape [- 2 :]
415
415
k = min (m , n )
416
416
orig_shape = a .shape
417
+ batch_shape = orig_shape [:- 2 ]
418
+
419
+ # accommodate empty arrays
420
+ if a .size == 0 :
421
+ lu = dpnp .empty_like (a )
422
+ piv = dpnp .empty (
423
+ (* batch_shape , k ),
424
+ dtype = dpnp .int64 ,
425
+ usm_type = a_usm_type ,
426
+ sycl_queue = a_sycl_queue ,
427
+ )
428
+ return lu , piv
429
+
417
430
# get 3d input arrays by reshape
418
431
a = dpnp .reshape (a , (- 1 , m , n ))
419
432
batch_size = a .shape [0 ]
@@ -468,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type):
468
481
# Batch was moved to the last axis before the call.
469
482
# Move it back to the front and reshape to the original shape.
470
483
a_h = dpnp .moveaxis (a_h , - 1 , 0 ).reshape (orig_shape )
471
- ipiv_h = ipiv_h .reshape ((* orig_shape [: - 2 ] , k ))
484
+ ipiv_h = ipiv_h .reshape ((* batch_shape , k ))
472
485
473
486
if any (dev_info_h ):
474
487
diag_nums = ", " .join (str (v ) for v in dev_info_h if v > 0 )
@@ -530,7 +543,7 @@ def _batched_lu_factor_scipy(a, res_type):
530
543
531
544
# Reshape the results back to their original shape
532
545
out_a = dpnp .array (a_vecs ).reshape (orig_shape )
533
- out_ipiv = dpnp .array (ipiv_vecs ).reshape ((* orig_shape [: - 2 ] , k ))
546
+ out_ipiv = dpnp .array (ipiv_vecs ).reshape ((* batch_shape , k ))
534
547
535
548
diag_nums = ", " .join (
536
549
str (v ) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
@@ -2463,14 +2476,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
2463
2476
a_sycl_queue = a .sycl_queue
2464
2477
a_usm_type = a .usm_type
2465
2478
2466
- # accommodate empty arrays
2467
- if a .size == 0 :
2468
- lu = dpnp .empty_like (a )
2469
- piv = dpnp .arange (
2470
- 0 , dtype = dpnp .int64 , usm_type = a_usm_type , sycl_queue = a_sycl_queue
2471
- )
2472
- return lu , piv
2473
-
2474
2479
if check_finite :
2475
2480
if not dpnp .isfinite (a ).all ():
2476
2481
raise ValueError ("array must not contain infs or NaNs" )
@@ -2480,6 +2485,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
2480
2485
# so `overwrite_a` is ignored here
2481
2486
return _batched_lu_factor_scipy (a , res_type )
2482
2487
2488
+ # accommodate empty arrays
2489
+ if a .size == 0 :
2490
+ lu = dpnp .empty_like (a )
2491
+ piv = dpnp .arange (
2492
+ 0 , dtype = dpnp .int64 , usm_type = a_usm_type , sycl_queue = a_sycl_queue
2493
+ )
2494
+ return lu , piv
2495
+
2483
2496
_manager = dpu .SequentialOrderManager [a_sycl_queue ]
2484
2497
a_usm_arr = dpnp .get_usm_ndarray (a )
2485
2498
0 commit comments