@@ -398,7 +398,7 @@ def _batched_lu_factor(a, res_type):
398398 return (out_a , out_ipiv , out_dev_info )
399399
400400
401- def _batched_lu_factor_scipy (a , res_type ):
401+ def _batched_lu_factor_scipy (a , res_type ): # pylint: disable=too-many-locals
402402 """SciPy-compatible LU factorization for batched inputs."""
403403
404404 # TODO: Find out at which array sizes the best performance is obtained
@@ -414,6 +414,19 @@ def _batched_lu_factor_scipy(a, res_type):
414414 m , n = a .shape [- 2 :]
415415 k = min (m , n )
416416 orig_shape = a .shape
417+ batch_shape = orig_shape [:- 2 ]
418+
419+ # accommodate empty arrays
420+ if a .size == 0 :
421+ lu = dpnp .empty_like (a )
422+ piv = dpnp .empty (
423+ (* batch_shape , k ),
424+ dtype = dpnp .int64 ,
425+ usm_type = a_usm_type ,
426+ sycl_queue = a_sycl_queue ,
427+ )
428+ return lu , piv
429+
417430 # get 3d input arrays by reshape
418431 a = dpnp .reshape (a , (- 1 , m , n ))
419432 batch_size = a .shape [0 ]
@@ -468,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type):
468481 # Batch was moved to the last axis before the call.
469482 # Move it back to the front and reshape to the original shape.
470483 a_h = dpnp .moveaxis (a_h , - 1 , 0 ).reshape (orig_shape )
471- ipiv_h = ipiv_h .reshape ((* orig_shape [: - 2 ] , k ))
484+ ipiv_h = ipiv_h .reshape ((* batch_shape , k ))
472485
473486 if any (dev_info_h ):
474487 diag_nums = ", " .join (str (v ) for v in dev_info_h if v > 0 )
@@ -530,7 +543,7 @@ def _batched_lu_factor_scipy(a, res_type):
530543
531544 # Reshape the results back to their original shape
532545 out_a = dpnp .array (a_vecs ).reshape (orig_shape )
533- out_ipiv = dpnp .array (ipiv_vecs ).reshape ((* orig_shape [: - 2 ] , k ))
546+ out_ipiv = dpnp .array (ipiv_vecs ).reshape ((* batch_shape , k ))
534547
535548 diag_nums = ", " .join (
536549 str (v ) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
@@ -2463,14 +2476,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24632476 a_sycl_queue = a .sycl_queue
24642477 a_usm_type = a .usm_type
24652478
2466- # accommodate empty arrays
2467- if a .size == 0 :
2468- lu = dpnp .empty_like (a )
2469- piv = dpnp .arange (
2470- 0 , dtype = dpnp .int64 , usm_type = a_usm_type , sycl_queue = a_sycl_queue
2471- )
2472- return lu , piv
2473-
24742479 if check_finite :
24752480 if not dpnp .isfinite (a ).all ():
24762481 raise ValueError ("array must not contain infs or NaNs" )
@@ -2480,6 +2485,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24802485 # so `overwrite_a` is ignored here
24812486 return _batched_lu_factor_scipy (a , res_type )
24822487
2488+ # accommodate empty arrays
2489+ if a .size == 0 :
2490+ lu = dpnp .empty_like (a )
2491+ piv = dpnp .arange (
2492+ 0 , dtype = dpnp .int64 , usm_type = a_usm_type , sycl_queue = a_sycl_queue
2493+ )
2494+ return lu , piv
2495+
24832496 _manager = dpu .SequentialOrderManager [a_sycl_queue ]
24842497 a_usm_arr = dpnp .get_usm_ndarray (a )
24852498
0 commit comments