@@ -107,6 +107,37 @@ class SVDResult(NamedTuple):
107107}
108108
109109
110+ def _align_lu_solve_broadcast (lu , b ):
111+ """Align LU and RHS batch dimensions with SciPy-like rules."""
112+ lu_shape = lu .shape
113+ b_shape = b .shape
114+
115+ if b .ndim < 2 :
116+ if lu_shape [- 2 ] != b_shape [0 ]:
117+ raise ValueError (
118+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
119+ )
120+ b = dpnp .broadcast_to (b , lu_shape [:- 1 ])
121+ return lu , b
122+
123+ if lu_shape [- 2 ] != b_shape [- 2 ]:
124+ raise ValueError (
125+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
126+ )
127+
128+ # Use dpnp.broadcast_shapes() to align the resulting batch shapes
129+ batch = dpnp .broadcast_shapes (lu_shape [:- 2 ], b_shape [:- 2 ])
130+ lu_bshape = batch + lu_shape [- 2 :]
131+ b_bshape = batch + b_shape [- 2 :]
132+
133+ if lu_shape != lu_bshape :
134+ lu = dpnp .broadcast_to (lu , lu_bshape )
135+ if b_shape != b_bshape :
136+ b = dpnp .broadcast_to (b , b_bshape )
137+
138+ return lu , b
139+
140+
110141def _batched_eigh (a , UPLO , eigen_mode , w_type , v_type ):
111142 """
112143 _batched_eigh(a, UPLO, eigen_mode, w_type, v_type)
@@ -486,6 +517,109 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
486517 return (a_h , ipiv_h )
487518
488519
520+ def _batched_lu_solve (lu , piv , b , res_type , trans = 0 ):
521+ """Solve a batched equation system (SciPy-compatible behavior)."""
522+ res_usm_type , exec_q = get_usm_allocations ([lu , piv , b ])
523+
524+ if b .size == 0 :
525+ return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
526+
527+ b_ndim = b .ndim
528+
529+ lu , b = _align_lu_solve_broadcast (lu , b )
530+
531+ n = lu .shape [- 1 ]
532+ nrhs = b .shape [- 1 ] if b_ndim > 1 else 1
533+
534+ # get 3d input arrays by reshape
535+ if lu .ndim > 3 :
536+ lu = dpnp .reshape (lu , (- 1 , n , n ))
537+ # get 2d pivot arrays by reshape
538+ if piv .ndim > 2 :
539+ piv = dpnp .reshape (piv , (- 1 , n ))
540+ batch_size = lu .shape [0 ]
541+
542+ # Move batch axis to the end (n, n, batch) in Fortran order:
543+ # required by getrs_batch
544+ # and ensures each a[..., i] is F-contiguous for getrs_batch
545+ lu = dpnp .moveaxis (lu , 0 , - 1 )
546+
547+ b_orig_shape = b .shape
548+ if b .ndim > 2 :
549+ b = dpnp .reshape (b , (- 1 , n , nrhs ))
550+
551+ # Move batch axis to the end (n, nrhs, batch) in Fortran order:
552+ # required by getrs_batch
553+ # and ensures each b[..., i] is F-contiguous for getrs_batch
554+ b = dpnp .moveaxis (b , 0 , - 1 )
555+
556+ lu_usm_arr = dpnp .get_usm_ndarray (lu )
557+ b_usm_arr = dpnp .get_usm_ndarray (b )
558+
559+ # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
560+ # convert to 1-based for oneMKL getrs_batch
561+ piv_h = piv + 1
562+
563+ _manager = dpu .SequentialOrderManager [exec_q ]
564+ dep_evs = _manager .submitted_events
565+
566+ # oneMKL LAPACK getrs overwrites `lu`.
567+ lu_h = dpnp .empty_like (lu , order = "F" , dtype = res_type , usm_type = res_usm_type )
568+
569+ # use DPCTL tensor function to fill the сopy of the input array
570+ # from the input array
571+ ht_ev , lu_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
572+ src = lu_usm_arr ,
573+ dst = lu_h .get_array (),
574+ sycl_queue = lu .sycl_queue ,
575+ depends = dep_evs ,
576+ )
577+ _manager .add_event_pair (ht_ev , lu_copy_ev )
578+
579+ b_h = dpnp .empty_like (b , order = "F" , dtype = res_type , usm_type = res_usm_type )
580+ ht_ev , b_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
581+ src = b_usm_arr ,
582+ dst = b_h .get_array (),
583+ sycl_queue = b .sycl_queue ,
584+ depends = dep_evs ,
585+ )
586+ _manager .add_event_pair (ht_ev , b_copy_ev )
587+ dep_evs = [lu_copy_ev , b_copy_ev ]
588+
589+ lu_stride = lu_h .strides [- 1 ]
590+ piv_stride = piv .strides [0 ]
591+ b_stride = b_h .strides [- 1 ]
592+
593+ if not isinstance (trans , int ):
594+ raise TypeError ("`trans` must be an integer" )
595+
596+ trans_mkl = _map_trans_to_mkl (trans )
597+
598+ # Call the LAPACK extension function _getrs_batch
599+ # to solve the system of linear equations with an LU-factored
600+ # coefficient square matrix, with multiple right-hand sides.
601+ ht_ev , getrs_batch_ev = li ._getrs_batch (
602+ exec_q ,
603+ lu_h .get_array (),
604+ piv_h .get_array (),
605+ b_h .get_array (),
606+ trans_mkl ,
607+ n ,
608+ nrhs ,
609+ lu_stride ,
610+ piv_stride ,
611+ b_stride ,
612+ batch_size ,
613+ depends = dep_evs ,
614+ )
615+ _manager .add_event_pair (ht_ev , getrs_batch_ev )
616+
617+ # Restore original shape: move batch axis back and reshape
618+ b_h = dpnp .moveaxis (b_h , - 1 , 0 ).reshape (b_orig_shape )
619+
620+ return b_h
621+
622+
489623def _batched_solve (a , b , exec_q , res_usm_type , res_type ):
490624 """
491625 _batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -1099,6 +1233,20 @@ def _is_empty_2d(arr):
10991233 return arr .size == 0 and numpy .prod (arr .shape [- 2 :]) == 0
11001234
11011235
1236+ def _map_trans_to_mkl (trans ):
1237+ """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
1238+ if not isinstance (trans , int ):
1239+ raise TypeError ("`trans` must be an integer" )
1240+
1241+ if trans == 0 :
1242+ return li .Transpose .N
1243+ if trans == 1 :
1244+ return li .Transpose .T
1245+ if trans == 2 :
1246+ return li .Transpose .C
1247+ raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
1248+
1249+
11021250def _lu_factor (a , res_type ):
11031251 """
11041252 Compute pivoted LU decomposition.
@@ -2493,18 +2641,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
24932641
24942642 res_type = _common_type (lu , b )
24952643
2496- # TODO: add broadcasting
2497- if lu .shape [0 ] != b .shape [0 ]:
2498- raise ValueError (
2499- f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
2500- )
2501-
25022644 if b .size == 0 :
25032645 return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
25042646
2505- if lu .ndim > 2 :
2506- raise NotImplementedError ("Batched matrices are not supported" )
2507-
25082647 if check_finite :
25092648 if not dpnp .isfinite (lu ).all ():
25102649 raise ValueError (
@@ -2517,6 +2656,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25172656 "Right-hand side array must not contain infs or NaNs"
25182657 )
25192658
2659+ if lu .ndim > 2 :
2660+ # SciPy always copies each 2D slice,
2661+ # so `overwrite_b` is ignored here
2662+ return _batched_lu_solve (lu , piv , b , trans = trans , res_type = res_type )
2663+
2664+ if lu .shape [0 ] != b .shape [0 ]:
2665+ raise ValueError (
2666+ f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
2667+ )
2668+
25202669 lu_usm_arr = dpnp .get_usm_ndarray (lu )
25212670 b_usm_arr = dpnp .get_usm_ndarray (b )
25222671
@@ -2563,18 +2712,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25632712 b_h = b
25642713 dep_evs = [lu_copy_ev ]
25652714
2566- if not isinstance (trans , int ):
2567- raise TypeError ("`trans` must be an integer" )
2568-
2569- # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
2570- if trans == 0 :
2571- trans_mkl = li .Transpose .N
2572- elif trans == 1 :
2573- trans_mkl = li .Transpose .T
2574- elif trans == 2 :
2575- trans_mkl = li .Transpose .C
2576- else :
2577- raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
2715+ trans_mkl = _map_trans_to_mkl (trans )
25782716
25792717 # Call the LAPACK extension function _getrs
25802718 # to solve the system of linear equations with an LU-factored
0 commit comments