@@ -402,9 +402,9 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
402402 """SciPy-compatible LU factorization for batched inputs."""
403403
404404 # TODO: Find out at which array sizes the best performance is obtained
405- # getrf_batch implementation shows slow results with large arrays on GPU.
405+ # getrf_batch can be slow on large GPU arrays .
406406 # Use getrf_batch only on CPU.
407- # On GPU call getrf for each two-dimensional array by loop
407+ # On GPU fall back to calling getrf per 2D slice.
408408 use_batch = a .sycl_device .has_aspect_cpu
409409
410410 a_sycl_queue = a .sycl_queue
@@ -416,7 +416,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
416416 orig_shape = a .shape
417417 batch_shape = orig_shape [:- 2 ]
418418
419- # accommodate empty arrays
419+ # handle empty input
420420 if a .size == 0 :
421421 lu = dpnp .empty_like (a )
422422 piv = dpnp .empty (
@@ -431,32 +431,33 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
431431 a = dpnp .reshape (a , (- 1 , m , n ))
432432 batch_size = a .shape [0 ]
433433
434- if use_batch :
435- # Reorder the elements by moving the last two axes of `a` to the front
436- # to match fortran-like array order which is assumed by getrf_batch
437- a = dpnp .moveaxis (a , 0 , - 1 )
434+ # Move the batch axis to the end, giving shape (m, n, batch):
435+ # getrf_batch requires this layout, and after the copy into the
436+ # F-ordered `a_h` below each a_h[..., i] is F-contiguous for getrf
437+ a = dpnp .moveaxis (a , 0 , - 1 )
438438
439- a_usm_arr = dpnp .get_usm_ndarray (a )
439+ a_usm_arr = dpnp .get_usm_ndarray (a )
440440
441- # `a` must be copied because getrf_batch destroys the input matrix
442- a_h = dpnp .empty_like (a , order = "F" , dtype = res_type )
443- ipiv_h = dpnp . empty (
444- ( batch_size , k ) ,
445- dtype = dpnp . int64 ,
446- order = "C" ,
447- usm_type = a_usm_type ,
448- sycl_queue = a_sycl_queue ,
449- )
441+ # `a` must be copied because getrf/getrf_batch destroy the input matrix
442+ a_h = dpnp .empty_like (a , order = "F" , dtype = res_type )
443+ ht_ev , copy_ev = ti . _copy_usm_ndarray_into_usm_ndarray (
444+ src = a_usm_arr ,
445+ dst = a_h . get_array () ,
446+ sycl_queue = a_sycl_queue ,
447+ depends = _manager . submitted_events ,
448+ )
449+ _manager . add_event_pair ( ht_ev , copy_ev )
450450
451- dev_info_h = [0 ] * batch_size
451+ ipiv_h = dpnp .empty (
452+ (batch_size , k ),
453+ dtype = dpnp .int64 ,
454+ order = "C" ,
455+ usm_type = a_usm_type ,
456+ sycl_queue = a_sycl_queue ,
457+ )
452458
453- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
454- src = a_usm_arr ,
455- dst = a_h .get_array (),
456- sycl_queue = a_sycl_queue ,
457- depends = _manager .submitted_events ,
458- )
459- _manager .add_event_pair (ht_ev , copy_ev )
459+ if use_batch :
460+ dev_info_h = [0 ] * batch_size
460461
461462 ipiv_stride = k
462463 a_stride = a_h .strides [- 1 ]
@@ -477,12 +478,6 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
477478 )
478479 _manager .add_event_pair (ht_ev , getrf_ev )
479480
480- # getrf_batch expects `a` in Fortran order and overwrites it.
481- # Batch was moved to the last axis before the call.
482- # Move it back to the front and reshape to the original shape.
483- a_h = dpnp .moveaxis (a_h , - 1 , 0 ).reshape (orig_shape )
484- ipiv_h = ipiv_h .reshape ((* batch_shape , k ))
485-
486481 if any (dev_info_h ):
487482 diag_nums = ", " .join (str (v ) for v in dev_info_h if v > 0 )
488483 warn (
@@ -491,77 +486,41 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
491486 RuntimeWarning ,
492487 stacklevel = 2 ,
493488 )
489+ else :
490+ dev_info_vecs = [[0 ] for _ in range (batch_size )]
491+
492+ # Sequential LU factorization using getrf per slice
493+ for i in range (batch_size ):
494+ ht_ev , getrf_ev = li ._getrf (
495+ a_sycl_queue ,
496+ a_h [..., i ].get_array (),
497+ ipiv_h [i ].get_array (),
498+ dev_info_vecs [i ],
499+ depends = [copy_ev ],
500+ )
501+ _manager .add_event_pair (ht_ev , getrf_ev )
494502
495- # MKL lapack uses 1-origin while SciPy uses 0-origin
496- ipiv_h -= 1
497-
498- # Return a tuple containing the factorized matrix 'a_h',
499- # pivot indices 'ipiv_h'
500- return (a_h , ipiv_h )
501-
502- a_usm_arr = dpnp .get_usm_ndarray (a )
503-
504- # Initialize lists for storing arrays and events for each batch
505- a_vecs = [None ] * batch_size
506- ipiv_vecs = [None ] * batch_size
507- dev_info_vecs = [None ] * batch_size
508- dep_evs = _manager .submitted_events
509-
510- # Process each batch
511- for i in range (batch_size ):
512- # Copy each 2D slice to a new array because getrf will destroy
513- # the input matrix
514- a_vecs [i ] = dpnp .empty_like (a [i ], order = "F" , dtype = res_type )
515- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
516- src = a_usm_arr [i ],
517- dst = a_vecs [i ].get_array (),
518- sycl_queue = a_sycl_queue ,
519- depends = dep_evs ,
520- )
521- _manager .add_event_pair (ht_ev , copy_ev )
522-
523- ipiv_vecs [i ] = dpnp .empty (
524- (k ,),
525- dtype = dpnp .int64 ,
526- order = "C" ,
527- usm_type = a_usm_type ,
528- sycl_queue = a_sycl_queue ,
503+ diag_nums = ", " .join (
504+ str (v ) for info in dev_info_vecs for v in info if v > 0
529505 )
506+ if diag_nums :
507+ warn (
508+ f"Diagonal number { diag_nums } are exactly zero. "
509+ "Singular matrix." ,
510+ RuntimeWarning ,
511+ stacklevel = 2 ,
512+ )
530513
531- dev_info_vecs [i ] = [0 ]
532-
533- # Call the LAPACK extension function _getrf
534- # to perform LU decomposition on each batch in 'a_vecs[i]'
535- ht_ev , getrf_ev = li ._getrf (
536- a_sycl_queue ,
537- a_vecs [i ].get_array (),
538- ipiv_vecs [i ].get_array (),
539- dev_info_vecs [i ],
540- depends = [copy_ev ],
541- )
542- _manager .add_event_pair (ht_ev , getrf_ev )
543-
544- # Reshape the results back to their original shape
545- out_a = dpnp .array (a_vecs ).reshape (orig_shape )
546- out_ipiv = dpnp .array (ipiv_vecs ).reshape ((* batch_shape , k ))
547-
548- diag_nums = ", " .join (
549- str (v ) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
550- )
551-
552- if diag_nums :
553- warn (
554- f"Diagonal number { diag_nums } are exactly zero. Singular matrix." ,
555- RuntimeWarning ,
556- stacklevel = 2 ,
557- )
514+ # Restore original shape: move batch axis back and reshape
515+ a_h = dpnp .moveaxis (a_h , - 1 , 0 ).reshape (orig_shape )
516+ ipiv_h = ipiv_h .reshape ((* batch_shape , k ))
558517
559- # MKL lapack uses 1-origin while SciPy uses 0-origin
560- out_ipiv -= 1
518+ # oneMKL LAPACK uses 1-origin while SciPy uses 0-origin
519+ ipiv_h -= 1
561520
562- # Return a tuple containing the factorized matrix 'out_a ',
563- # pivot indices 'out_ipiv '
564- return (out_a , out_ipiv )
521+ # Return a tuple containing the factorized matrix 'a_h ',
522+ # pivot indices 'ipiv_h '
523+ return (a_h , ipiv_h )
565524
566525
567526def _batched_solve (a , b , exec_q , res_usm_type , res_type ):
0 commit comments