@@ -398,6 +398,159 @@ def _batched_lu_factor(a, res_type):
398
398
return (out_a , out_ipiv , out_dev_info )
399
399
400
400
401
def _batched_lu_factor_scipy(a, res_type):
    """SciPy-compatible LU factorization for batched inputs.

    Parameters
    ----------
    a : dpnp.ndarray
        Input array with ``a.ndim > 2``; the last two dimensions hold the
        ``(m, n)`` matrices to factorize.
    res_type : dtype
        Data type of the returned factorized matrices.

    Returns
    -------
    tuple of dpnp.ndarray
        ``(lu, piv)`` where ``lu`` has the original shape of `a` and
        ``piv`` has shape ``(*a.shape[:-2], min(m, n))`` holding 0-based
        pivot indices (SciPy convention, unlike MKL's 1-based).

    Warns
    -----
    RuntimeWarning
        If any factorized matrix is exactly singular (a zero on the
        diagonal of U, reported via a positive LAPACK ``info`` value).
    """

    # TODO: Find out at which array sizes the best performance is obtained
    # getrf_batch implementation shows slow results with large arrays on GPU.
    # Use getrf_batch only on CPU.
    # On GPU call getrf for each two-dimensional array by loop
    use_batch = a.sycl_device.has_aspect_cpu

    a_sycl_queue = a.sycl_queue
    a_usm_type = a.usm_type
    _manager = dpu.SequentialOrderManager[a_sycl_queue]

    m, n = a.shape[-2:]
    k = min(m, n)
    orig_shape = a.shape
    # get 3d input arrays by reshape
    a = dpnp.reshape(a, (-1, m, n))
    batch_size = a.shape[0]

    if use_batch:
        # Reorder the elements by moving the last two axes of `a` to the front
        # to match fortran-like array order which is assumed by getrf_batch
        a = dpnp.moveaxis(a, 0, -1)

        a_usm_arr = dpnp.get_usm_ndarray(a)

        # `a` must be copied because getrf_batch destroys the input matrix
        a_h = dpnp.empty_like(a, order="F", dtype=res_type)
        ipiv_h = dpnp.empty(
            (batch_size, k),
            dtype=dpnp.int64,
            order="C",
            usm_type=a_usm_type,
            sycl_queue=a_sycl_queue,
        )

        dev_info_h = [0] * batch_size

        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=a_usm_arr,
            dst=a_h.get_array(),
            sycl_queue=a_sycl_queue,
            depends=_manager.submitted_events,
        )
        _manager.add_event_pair(ht_ev, copy_ev)

        ipiv_stride = k
        a_stride = a_h.strides[-1]

        # Call the LAPACK extension function _getrf_batch
        # to perform LU decomposition of a batch of general matrices
        ht_ev, getrf_ev = li._getrf_batch(
            a_sycl_queue,
            a_h.get_array(),
            ipiv_h.get_array(),
            dev_info_h,
            m,
            n,
            a_stride,
            ipiv_stride,
            batch_size,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_ev, getrf_ev)

        # getrf_batch expects `a` in Fortran order and overwrites it.
        # Batch was moved to the last axis before the call.
        # Move it back to the front and reshape to the original shape.
        a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
        ipiv_h = ipiv_h.reshape((*orig_shape[:-2], k))

        # Only positive `info` values indicate an exact zero on U's
        # diagonal. Build `diag_nums` first and gate the warning on it,
        # so a negative `info` cannot trigger a warning with an empty
        # diagonal list; this also matches the per-matrix loop below.
        diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
        if diag_nums:
            warn(
                f"Diagonal number {diag_nums} are exactly zero. "
                "Singular matrix.",
                RuntimeWarning,
                stacklevel=2,
            )

        # MKL lapack uses 1-origin while SciPy uses 0-origin
        ipiv_h -= 1

        # Return a tuple containing the factorized matrix 'a_h',
        # pivot indices 'ipiv_h'
        return (a_h, ipiv_h)

    a_usm_arr = dpnp.get_usm_ndarray(a)

    # Initialize lists for storing arrays and events for each batch
    a_vecs = [None] * batch_size
    ipiv_vecs = [None] * batch_size
    dev_info_vecs = [None] * batch_size
    dep_evs = _manager.submitted_events

    # Process each batch
    for i in range(batch_size):
        # Copy each 2D slice to a new array because getrf will destroy
        # the input matrix
        a_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type)
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=a_usm_arr[i],
            dst=a_vecs[i].get_array(),
            sycl_queue=a_sycl_queue,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, copy_ev)

        ipiv_vecs[i] = dpnp.empty(
            (k,),
            dtype=dpnp.int64,
            order="C",
            usm_type=a_usm_type,
            sycl_queue=a_sycl_queue,
        )

        dev_info_vecs[i] = [0]

        # Call the LAPACK extension function _getrf
        # to perform LU decomposition on each batch in 'a_vecs[i]'
        ht_ev, getrf_ev = li._getrf(
            a_sycl_queue,
            a_vecs[i].get_array(),
            ipiv_vecs[i].get_array(),
            dev_info_vecs[i],
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_ev, getrf_ev)

    # Reshape the results back to their original shape
    out_a = dpnp.array(a_vecs).reshape(orig_shape)
    out_ipiv = dpnp.array(ipiv_vecs).reshape((*orig_shape[:-2], k))

    diag_nums = ", ".join(
        str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
    )

    if diag_nums:
        warn(
            f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
            RuntimeWarning,
            stacklevel=2,
        )

    # MKL lapack uses 1-origin while SciPy uses 0-origin
    out_ipiv -= 1

    # Return a tuple containing the factorized matrix 'out_a',
    # pivot indices 'out_ipiv'
    return (out_a, out_ipiv)
553
+
401
554
def _batched_solve (a , b , exec_q , res_usm_type , res_type ):
402
555
"""
403
556
_batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -2323,7 +2476,9 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
2323
2476
raise ValueError ("array must not contain infs or NaNs" )
2324
2477
2325
2478
if a .ndim > 2 :
2326
- raise NotImplementedError ("Batched matrices are not supported" )
2479
+ # SciPy always copies each 2D slice,
2480
+ # so `overwrite_a` is ignored here
2481
+ return _batched_lu_factor_scipy (a , res_type )
2327
2482
2328
2483
_manager = dpu .SequentialOrderManager [a_sycl_queue ]
2329
2484
a_usm_arr = dpnp .get_usm_ndarray (a )
0 commit comments