@@ -402,9 +402,9 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
402
402
"""SciPy-compatible LU factorization for batched inputs."""
403
403
404
404
# TODO: Find out at which array sizes the best performance is obtained
405
- # getrf_batch implementation shows slow results with large arrays on GPU.
405
+ # getrf_batch can be slow on large GPU arrays .
406
406
# Use getrf_batch only on CPU.
407
- # On GPU call getrf for each two-dimensional array by loop
407
+ # On GPU fall back to calling getrf per 2D slice.
408
408
use_batch = a .sycl_device .has_aspect_cpu
409
409
410
410
a_sycl_queue = a .sycl_queue
@@ -416,7 +416,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
416
416
orig_shape = a .shape
417
417
batch_shape = orig_shape [:- 2 ]
418
418
419
- # accommodate empty arrays
419
+ # handle empty input
420
420
if a .size == 0 :
421
421
lu = dpnp .empty_like (a )
422
422
piv = dpnp .empty (
@@ -431,32 +431,33 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
431
431
a = dpnp .reshape (a , (- 1 , m , n ))
432
432
batch_size = a .shape [0 ]
433
433
434
- if use_batch :
435
- # Reorder the elements by moving the last two axes of `a` to the front
436
- # to match fortran-like array order which is assumed by getrf_batch
437
- a = dpnp .moveaxis (a , 0 , - 1 )
434
+ # Move batch axis to the end (m, n, batch) in Fortran order :
435
+ # required by getrf_batch
436
+ # and ensures each a[..., i] is F-contiguous for getrf
437
+ a = dpnp .moveaxis (a , 0 , - 1 )
438
438
439
- a_usm_arr = dpnp .get_usm_ndarray (a )
439
+ a_usm_arr = dpnp .get_usm_ndarray (a )
440
440
441
- # `a` must be copied because getrf_batch destroys the input matrix
442
- a_h = dpnp .empty_like (a , order = "F" , dtype = res_type )
443
- ipiv_h = dpnp . empty (
444
- ( batch_size , k ) ,
445
- dtype = dpnp . int64 ,
446
- order = "C" ,
447
- usm_type = a_usm_type ,
448
- sycl_queue = a_sycl_queue ,
449
- )
441
+ # `a` must be copied because getrf/ getrf_batch destroys the input matrix
442
+ a_h = dpnp .empty_like (a , order = "F" , dtype = res_type )
443
+ ht_ev , copy_ev = ti . _copy_usm_ndarray_into_usm_ndarray (
444
+ src = a_usm_arr ,
445
+ dst = a_h . get_array () ,
446
+ sycl_queue = a_sycl_queue ,
447
+ depends = _manager . submitted_events ,
448
+ )
449
+ _manager . add_event_pair ( ht_ev , copy_ev )
450
450
451
- dev_info_h = [0 ] * batch_size
451
+ ipiv_h = dpnp .empty (
452
+ (batch_size , k ),
453
+ dtype = dpnp .int64 ,
454
+ order = "C" ,
455
+ usm_type = a_usm_type ,
456
+ sycl_queue = a_sycl_queue ,
457
+ )
452
458
453
- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
454
- src = a_usm_arr ,
455
- dst = a_h .get_array (),
456
- sycl_queue = a_sycl_queue ,
457
- depends = _manager .submitted_events ,
458
- )
459
- _manager .add_event_pair (ht_ev , copy_ev )
459
+ if use_batch :
460
+ dev_info_h = [0 ] * batch_size
460
461
461
462
ipiv_stride = k
462
463
a_stride = a_h .strides [- 1 ]
@@ -477,12 +478,6 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
477
478
)
478
479
_manager .add_event_pair (ht_ev , getrf_ev )
479
480
480
- # getrf_batch expects `a` in Fortran order and overwrites it.
481
- # Batch was moved to the last axis before the call.
482
- # Move it back to the front and reshape to the original shape.
483
- a_h = dpnp .moveaxis (a_h , - 1 , 0 ).reshape (orig_shape )
484
- ipiv_h = ipiv_h .reshape ((* batch_shape , k ))
485
-
486
481
if any (dev_info_h ):
487
482
diag_nums = ", " .join (str (v ) for v in dev_info_h if v > 0 )
488
483
warn (
@@ -491,77 +486,41 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
491
486
RuntimeWarning ,
492
487
stacklevel = 2 ,
493
488
)
489
+ else :
490
+ dev_info_vecs = [[0 ] for _ in range (batch_size )]
491
+
492
+ # Sequential LU factorization using getrf per slice
493
+ for i in range (batch_size ):
494
+ ht_ev , getrf_ev = li ._getrf (
495
+ a_sycl_queue ,
496
+ a_h [..., i ].get_array (),
497
+ ipiv_h [i ].get_array (),
498
+ dev_info_vecs [i ],
499
+ depends = [copy_ev ],
500
+ )
501
+ _manager .add_event_pair (ht_ev , getrf_ev )
494
502
495
- # MKL lapack uses 1-origin while SciPy uses 0-origin
496
- ipiv_h -= 1
497
-
498
- # Return a tuple containing the factorized matrix 'a_h',
499
- # pivot indices 'ipiv_h'
500
- return (a_h , ipiv_h )
501
-
502
- a_usm_arr = dpnp .get_usm_ndarray (a )
503
-
504
- # Initialize lists for storing arrays and events for each batch
505
- a_vecs = [None ] * batch_size
506
- ipiv_vecs = [None ] * batch_size
507
- dev_info_vecs = [None ] * batch_size
508
- dep_evs = _manager .submitted_events
509
-
510
- # Process each batch
511
- for i in range (batch_size ):
512
- # Copy each 2D slice to a new array because getrf will destroy
513
- # the input matrix
514
- a_vecs [i ] = dpnp .empty_like (a [i ], order = "F" , dtype = res_type )
515
- ht_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
516
- src = a_usm_arr [i ],
517
- dst = a_vecs [i ].get_array (),
518
- sycl_queue = a_sycl_queue ,
519
- depends = dep_evs ,
520
- )
521
- _manager .add_event_pair (ht_ev , copy_ev )
522
-
523
- ipiv_vecs [i ] = dpnp .empty (
524
- (k ,),
525
- dtype = dpnp .int64 ,
526
- order = "C" ,
527
- usm_type = a_usm_type ,
528
- sycl_queue = a_sycl_queue ,
503
+ diag_nums = ", " .join (
504
+ str (v ) for info in dev_info_vecs for v in info if v > 0
529
505
)
506
+ if diag_nums :
507
+ warn (
508
+ f"Diagonal number { diag_nums } are exactly zero. "
509
+ "Singular matrix." ,
510
+ RuntimeWarning ,
511
+ stacklevel = 2 ,
512
+ )
530
513
531
- dev_info_vecs [i ] = [0 ]
532
-
533
- # Call the LAPACK extension function _getrf
534
- # to perform LU decomposition on each batch in 'a_vecs[i]'
535
- ht_ev , getrf_ev = li ._getrf (
536
- a_sycl_queue ,
537
- a_vecs [i ].get_array (),
538
- ipiv_vecs [i ].get_array (),
539
- dev_info_vecs [i ],
540
- depends = [copy_ev ],
541
- )
542
- _manager .add_event_pair (ht_ev , getrf_ev )
543
-
544
- # Reshape the results back to their original shape
545
- out_a = dpnp .array (a_vecs ).reshape (orig_shape )
546
- out_ipiv = dpnp .array (ipiv_vecs ).reshape ((* batch_shape , k ))
547
-
548
- diag_nums = ", " .join (
549
- str (v ) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
550
- )
551
-
552
- if diag_nums :
553
- warn (
554
- f"Diagonal number { diag_nums } are exactly zero. Singular matrix." ,
555
- RuntimeWarning ,
556
- stacklevel = 2 ,
557
- )
514
+ # Restore original shape: move batch axis back and reshape
515
+ a_h = dpnp .moveaxis (a_h , - 1 , 0 ).reshape (orig_shape )
516
+ ipiv_h = ipiv_h .reshape ((* batch_shape , k ))
558
517
559
- # MKL lapack uses 1-origin while SciPy uses 0-origin
560
- out_ipiv -= 1
518
+ # oneMKL LAPACK uses 1-origin while SciPy uses 0-origin
519
+ ipiv_h -= 1
561
520
562
- # Return a tuple containing the factorized matrix 'out_a ',
563
- # pivot indices 'out_ipiv '
564
- return (out_a , out_ipiv )
521
+ # Return a tuple containing the factorized matrix 'a_h ',
522
+ # pivot indices 'ipiv_h '
523
+ return (a_h , ipiv_h )
565
524
566
525
567
526
def _batched_solve (a , b , exec_q , res_usm_type , res_type ):
0 commit comments