@@ -107,6 +107,37 @@ class SVDResult(NamedTuple):
107
107
}
108
108
109
109
110
+ def _align_lu_solve_broadcast (lu , b ):
111
+ """Align LU and RHS batch dimensions with SciPy-like rules."""
112
+ lu_shape = lu .shape
113
+ b_shape = b .shape
114
+
115
+ if b .ndim < 2 :
116
+ if lu_shape [- 2 ] != b_shape [0 ]:
117
+ raise ValueError (
118
+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
119
+ )
120
+ b = dpnp .broadcast_to (b , lu_shape [:- 1 ])
121
+ return lu , b
122
+
123
+ if lu_shape [- 2 ] != b_shape [- 2 ]:
124
+ raise ValueError (
125
+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
126
+ )
127
+
128
+ # Use dpnp.broadcast_shapes() to align the resulting batch shapes
129
+ batch = dpnp .broadcast_shapes (lu_shape [:- 2 ], b_shape [:- 2 ])
130
+ lu_bshape = batch + lu_shape [- 2 :]
131
+ b_bshape = batch + b_shape [- 2 :]
132
+
133
+ if lu_shape != lu_bshape :
134
+ lu = dpnp .broadcast_to (lu , lu_bshape )
135
+ if b_shape != b_bshape :
136
+ b = dpnp .broadcast_to (b , b_bshape )
137
+
138
+ return lu , b
139
+
140
+
110
141
def _batched_eigh (a , UPLO , eigen_mode , w_type , v_type ):
111
142
"""
112
143
_batched_eigh(a, UPLO, eigen_mode, w_type, v_type)
@@ -486,6 +517,109 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
486
517
return (a_h , ipiv_h )
487
518
488
519
520
+ def _batched_lu_solve (lu , piv , b , res_type , trans = 0 ):
521
+ """Solve a batched equation system (SciPy-compatible behavior)."""
522
+ res_usm_type , exec_q = get_usm_allocations ([lu , piv , b ])
523
+
524
+ if b .size == 0 :
525
+ return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
526
+
527
+ b_ndim = b .ndim
528
+
529
+ lu , b = _align_lu_solve_broadcast (lu , b )
530
+
531
+ n = lu .shape [- 1 ]
532
+ nrhs = b .shape [- 1 ] if b_ndim > 1 else 1
533
+
534
+ # get 3d input arrays by reshape
535
+ if lu .ndim > 3 :
536
+ lu = dpnp .reshape (lu , (- 1 , n , n ))
537
+ # get 2d pivot arrays by reshape
538
+ if piv .ndim > 2 :
539
+ piv = dpnp .reshape (piv , (- 1 , n ))
540
+ batch_size = lu .shape [0 ]
541
+
542
+ # Move batch axis to the end (n, n, batch) in Fortran order:
543
+ # required by getrs_batch
544
+ # and ensures each a[..., i] is F-contiguous for getrs_batch
545
+ lu = dpnp .moveaxis (lu , 0 , - 1 )
546
+
547
+ b_orig_shape = b .shape
548
+ if b .ndim > 2 :
549
+ b = dpnp .reshape (b , (- 1 , n , nrhs ))
550
+
551
+ # Move batch axis to the end (n, nrhs, batch) in Fortran order:
552
+ # required by getrs_batch
553
+ # and ensures each b[..., i] is F-contiguous for getrs_batch
554
+ b = dpnp .moveaxis (b , 0 , - 1 )
555
+
556
+ lu_usm_arr = dpnp .get_usm_ndarray (lu )
557
+ b_usm_arr = dpnp .get_usm_ndarray (b )
558
+
559
+ # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
560
+ # convert to 1-based for oneMKL getrs_batch
561
+ piv_h = piv + 1
562
+
563
+ _manager = dpu .SequentialOrderManager [exec_q ]
564
+ dep_evs = _manager .submitted_events
565
+
566
+ # oneMKL LAPACK getrs overwrites `lu`.
567
+ lu_h = dpnp .empty_like (lu , order = "F" , dtype = res_type , usm_type = res_usm_type )
568
+
569
+ # use DPCTL tensor function to fill the сopy of the input array
570
+ # from the input array
571
+ ht_ev , lu_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
572
+ src = lu_usm_arr ,
573
+ dst = lu_h .get_array (),
574
+ sycl_queue = lu .sycl_queue ,
575
+ depends = dep_evs ,
576
+ )
577
+ _manager .add_event_pair (ht_ev , lu_copy_ev )
578
+
579
+ b_h = dpnp .empty_like (b , order = "F" , dtype = res_type , usm_type = res_usm_type )
580
+ ht_ev , b_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
581
+ src = b_usm_arr ,
582
+ dst = b_h .get_array (),
583
+ sycl_queue = b .sycl_queue ,
584
+ depends = dep_evs ,
585
+ )
586
+ _manager .add_event_pair (ht_ev , b_copy_ev )
587
+ dep_evs = [lu_copy_ev , b_copy_ev ]
588
+
589
+ lu_stride = lu_h .strides [- 1 ]
590
+ piv_stride = piv .strides [0 ]
591
+ b_stride = b_h .strides [- 1 ]
592
+
593
+ if not isinstance (trans , int ):
594
+ raise TypeError ("`trans` must be an integer" )
595
+
596
+ trans_mkl = _map_trans_to_mkl (trans )
597
+
598
+ # Call the LAPACK extension function _getrs_batch
599
+ # to solve the system of linear equations with an LU-factored
600
+ # coefficient square matrix, with multiple right-hand sides.
601
+ ht_ev , getrs_batch_ev = li ._getrs_batch (
602
+ exec_q ,
603
+ lu_h .get_array (),
604
+ piv_h .get_array (),
605
+ b_h .get_array (),
606
+ trans_mkl ,
607
+ n ,
608
+ nrhs ,
609
+ lu_stride ,
610
+ piv_stride ,
611
+ b_stride ,
612
+ batch_size ,
613
+ depends = dep_evs ,
614
+ )
615
+ _manager .add_event_pair (ht_ev , getrs_batch_ev )
616
+
617
+ # Restore original shape: move batch axis back and reshape
618
+ b_h = dpnp .moveaxis (b_h , - 1 , 0 ).reshape (b_orig_shape )
619
+
620
+ return b_h
621
+
622
+
489
623
def _batched_solve (a , b , exec_q , res_usm_type , res_type ):
490
624
"""
491
625
_batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -1099,6 +1233,20 @@ def _is_empty_2d(arr):
1099
1233
return arr .size == 0 and numpy .prod (arr .shape [- 2 :]) == 0
1100
1234
1101
1235
1236
+ def _map_trans_to_mkl (trans ):
1237
+ """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
1238
+ if not isinstance (trans , int ):
1239
+ raise TypeError ("`trans` must be an integer" )
1240
+
1241
+ if trans == 0 :
1242
+ return li .Transpose .N
1243
+ if trans == 1 :
1244
+ return li .Transpose .T
1245
+ if trans == 2 :
1246
+ return li .Transpose .C
1247
+ raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
1248
+
1249
+
1102
1250
def _lu_factor (a , res_type ):
1103
1251
"""
1104
1252
Compute pivoted LU decomposition.
@@ -2493,18 +2641,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2493
2641
2494
2642
res_type = _common_type (lu , b )
2495
2643
2496
- # TODO: add broadcasting
2497
- if lu .shape [0 ] != b .shape [0 ]:
2498
- raise ValueError (
2499
- f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
2500
- )
2501
-
2502
2644
if b .size == 0 :
2503
2645
return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
2504
2646
2505
- if lu .ndim > 2 :
2506
- raise NotImplementedError ("Batched matrices are not supported" )
2507
-
2508
2647
if check_finite :
2509
2648
if not dpnp .isfinite (lu ).all ():
2510
2649
raise ValueError (
@@ -2517,6 +2656,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2517
2656
"Right-hand side array must not contain infs or NaNs"
2518
2657
)
2519
2658
2659
+ if lu .ndim > 2 :
2660
+ # SciPy always copies each 2D slice,
2661
+ # so `overwrite_b` is ignored here
2662
+ return _batched_lu_solve (lu , piv , b , trans = trans , res_type = res_type )
2663
+
2664
+ if lu .shape [0 ] != b .shape [0 ]:
2665
+ raise ValueError (
2666
+ f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
2667
+ )
2668
+
2520
2669
lu_usm_arr = dpnp .get_usm_ndarray (lu )
2521
2670
b_usm_arr = dpnp .get_usm_ndarray (b )
2522
2671
@@ -2563,18 +2712,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2563
2712
b_h = b
2564
2713
dep_evs = [lu_copy_ev ]
2565
2714
2566
- if not isinstance (trans , int ):
2567
- raise TypeError ("`trans` must be an integer" )
2568
-
2569
- # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
2570
- if trans == 0 :
2571
- trans_mkl = li .Transpose .N
2572
- elif trans == 1 :
2573
- trans_mkl = li .Transpose .T
2574
- elif trans == 2 :
2575
- trans_mkl = li .Transpose .C
2576
- else :
2577
- raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
2715
+ trans_mkl = _map_trans_to_mkl (trans )
2578
2716
2579
2717
# Call the LAPACK extension function _getrs
2580
2718
# to solve the system of linear equations with an LU-factored
0 commit comments