@@ -398,6 +398,159 @@ def _batched_lu_factor(a, res_type):
     return (out_a, out_ipiv, out_dev_info)


+def _batched_lu_factor_scipy(a, res_type):
+    """SciPy-compatible LU factorization for batched inputs."""
+
+    # TODO: Find out at which array sizes the best performance is obtained.
+    # The getrf_batch implementation shows slow results with large arrays
+    # on GPU. Use getrf_batch only on CPU.
+    # On GPU, call getrf for each two-dimensional array in a loop.
+    use_batch = a.sycl_device.has_aspect_cpu
+
+    a_sycl_queue = a.sycl_queue
+    a_usm_type = a.usm_type
+    _manager = dpu.SequentialOrderManager[a_sycl_queue]
+
+    m, n = a.shape[-2:]
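+    # getrf computes min(m, n) pivot indices for an m x n matrix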
+    k = min(m, n)
+    orig_shape = a.shape
+    # Reshape the input into a 3D batch of matrices
+    a = dpnp.reshape(a, (-1, m, n))
+    batch_size = a.shape[0]
+
+    if use_batch:
+        # Reorder the elements by moving the last two axes of `a` to the front
+        # to match the Fortran-like array order which is assumed by getrf_batch
+        a = dpnp.moveaxis(a, 0, -1)
+
+        a_usm_arr = dpnp.get_usm_ndarray(a)
+
+        # `a` must be copied because getrf_batch destroys the input matrix
+        a_h = dpnp.empty_like(a, order="F", dtype=res_type)
+        ipiv_h = dpnp.empty(
+            (batch_size, k),
+            dtype=dpnp.int64,
+            order="C",
+            usm_type=a_usm_type,
+            sycl_queue=a_sycl_queue,
+        )
+
+        dev_info_h = [0] * batch_size
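+        # getrf_batch reports one info code per matrix; a positive value i
+        # means that U[i - 1, i - 1] of that factorization is exactly zero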
+
+        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=a_usm_arr,
+            dst=a_h.get_array(),
+            sycl_queue=a_sycl_queue,
+            depends=_manager.submitted_events,
+        )
+        _manager.add_event_pair(ht_ev, copy_ev)
+
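+        # Strides (in elements) between consecutive pivot vectors and
+        # consecutive matrices in the batch; `a_h` is Fortran-ordered with
+        # the batch on the last axis, so `strides[-1]` equals m * n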
+        ipiv_stride = k
+        a_stride = a_h.strides[-1]
+
+        # Call the LAPACK extension function _getrf_batch
+        # to perform LU decomposition of a batch of general matrices
+        ht_ev, getrf_ev = li._getrf_batch(
+            a_sycl_queue,
+            a_h.get_array(),
+            ipiv_h.get_array(),
+            dev_info_h,
+            m,
+            n,
+            a_stride,
+            ipiv_stride,
+            batch_size,
+            depends=[copy_ev],
+        )
+        _manager.add_event_pair(ht_ev, getrf_ev)
+
+        # getrf_batch expects `a` in Fortran order and overwrites it.
+        # The batch axis was moved to the last position before the call.
+        # Move it back to the front and reshape to the original shape.
+        a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
+        ipiv_h = ipiv_h.reshape((*orig_shape[:-2], k))
+
+        if any(dev_info_h):
+            diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
+            warn(
+                f"Diagonal numbers {diag_nums} are exactly zero. "
+                "Singular matrix.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+
+        # MKL LAPACK uses 1-origin pivot indices while SciPy uses 0-origin
+        ipiv_h -= 1
+
+        # Return a tuple containing the factorized matrix `a_h`
+        # and the pivot indices `ipiv_h`
+        return (a_h, ipiv_h)
+
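+    # Fallback path (non-CPU devices): factor each 2D matrix with getrf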
+    a_usm_arr = dpnp.get_usm_ndarray(a)
+
+    # Initialize lists for storing arrays and events for each batch
+    a_vecs = [None] * batch_size
+    ipiv_vecs = [None] * batch_size
+    dev_info_vecs = [None] * batch_size
+    dep_evs = _manager.submitted_events
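+    # The copies below must wait for any operations still pending on `a`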
+
+    # Process each matrix in the batch
+    for i in range(batch_size):
+        # Copy each 2D slice to a new array because getrf will destroy
+        # the input matrix
+        a_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type)
+        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=a_usm_arr[i],
+            dst=a_vecs[i].get_array(),
+            sycl_queue=a_sycl_queue,
+            depends=dep_evs,
+        )
+        _manager.add_event_pair(ht_ev, copy_ev)
+
+        ipiv_vecs[i] = dpnp.empty(
+            (k,),
+            dtype=dpnp.int64,
+            order="C",
+            usm_type=a_usm_type,
+            sycl_queue=a_sycl_queue,
+        )
+
+        dev_info_vecs[i] = [0]
+
+        # Call the LAPACK extension function _getrf
+        # to perform LU decomposition on the matrix in `a_vecs[i]`
+        ht_ev, getrf_ev = li._getrf(
+            a_sycl_queue,
+            a_vecs[i].get_array(),
+            ipiv_vecs[i].get_array(),
+            dev_info_vecs[i],
+            depends=[copy_ev],
+        )
+        _manager.add_event_pair(ht_ev, getrf_ev)
+
+    # Reshape the results back to their original shape
+    out_a = dpnp.array(a_vecs).reshape(orig_shape)
+    out_ipiv = dpnp.array(ipiv_vecs).reshape((*orig_shape[:-2], k))
+
+    diag_nums = ", ".join(
+        str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
+    )
+
+    if diag_nums:
+        warn(
+            f"Diagonal numbers {diag_nums} are exactly zero. Singular matrix.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+    # MKL LAPACK uses 1-origin pivot indices while SciPy uses 0-origin
+    out_ipiv -= 1
+
+    # Return a tuple containing the factorized matrix `out_a`
+    # and the pivot indices `out_ipiv`
+    return (out_a, out_ipiv)
+
+
 def _batched_solve(a, b, exec_q, res_usm_type, res_type):
     """
     _batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -2323,7 +2476,9 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
             raise ValueError("array must not contain infs or NaNs")

     if a.ndim > 2:
-        raise NotImplementedError("Batched matrices are not supported")
+        # SciPy always copies each 2D slice,
+        # so `overwrite_a` is ignored here
+        return _batched_lu_factor_scipy(a, res_type)

     _manager = dpu.SequentialOrderManager[a_sycl_queue]
     a_usm_arr = dpnp.get_usm_ndarray(a)
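
With this change, batched input no longer raises `NotImplementedError`: `dpnp_lu_factor` dispatches stacks of matrices to `_batched_lu_factor_scipy`. A minimal usage sketch follows, assuming the helper is reached through a SciPy-compatible `lu_factor` wrapper; the public entry point and its import path are not shown in this diff and are assumed here:

    import dpnp

    # A batch of four 3x3 matrices on the default SYCL device
    a = dpnp.random.rand(4, 3, 3)

    # Assumed public entry point; this diff only shows the internal
    # helpers, so the exact import path may differ.
    from dpnp.scipy.linalg import lu_factor

    # lu has shape (4, 3, 3); piv has shape (4, 3) and is 0-based,
    # matching scipy.linalg.lu_factor rather than raw LAPACK output.
    lu, piv = lu_factor(a)

Note that the returned pivots follow SciPy's 0-origin convention because both code paths subtract 1 from the 1-origin indices produced by MKL LAPACK.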