5555]
5656
5757
58+ def _align_lu_solve_broadcast (lu , b ):
59+ """Align LU and RHS batch dimensions with SciPy-like rules."""
60+ lu_shape = lu .shape
61+ b_shape = b .shape
62+
63+ if b .ndim < 2 :
64+ if lu_shape [- 2 ] != b_shape [0 ]:
65+ raise ValueError (
66+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
67+ )
68+ b = dpnp .broadcast_to (b , lu_shape [:- 1 ])
69+ return lu , b
70+
71+ if lu_shape [- 2 ] != b_shape [- 2 ]:
72+ raise ValueError (
73+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
74+ )
75+
76+ # Use dpnp.broadcast_shapes() to align the resulting batch shapes
77+ batch = dpnp .broadcast_shapes (lu_shape [:- 2 ], b_shape [:- 2 ])
78+ lu_bshape = batch + lu_shape [- 2 :]
79+ b_bshape = batch + b_shape [- 2 :]
80+
81+ if lu_shape != lu_bshape :
82+ lu = dpnp .broadcast_to (lu , lu_bshape )
83+ if b_shape != b_bshape :
84+ b = dpnp .broadcast_to (b , b_bshape )
85+
86+ return lu , b
87+
88+
5889def _batched_lu_factor_scipy (a , res_type ): # pylint: disable=too-many-locals
5990 """SciPy-compatible LU factorization for batched inputs."""
6091
@@ -180,6 +211,106 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
180211 return (a_h , ipiv_h )
181212
182213
214+ def _batched_lu_solve (lu , piv , b , res_type , trans = 0 ):
215+ """Solve a batched equation system (SciPy-compatible behavior)."""
216+ res_usm_type , exec_q = get_usm_allocations ([lu , piv , b ])
217+
218+ if b .size == 0 :
219+ return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
220+
221+ b_ndim = b .ndim
222+
223+ lu , b = _align_lu_solve_broadcast (lu , b )
224+
225+ n = lu .shape [- 1 ]
226+ nrhs = b .shape [- 1 ] if b_ndim > 1 else 1
227+
228+ # get 3d input arrays by reshape
229+ if lu .ndim > 3 :
230+ lu = dpnp .reshape (lu , (- 1 , n , n ))
231+ # get 2d pivot arrays by reshape
232+ if piv .ndim > 2 :
233+ piv = dpnp .reshape (piv , (- 1 , n ))
234+ batch_size = lu .shape [0 ]
235+
236+ # Move batch axis to the end (n, n, batch) in Fortran order:
237+ # required by getrs_batch
238+ # and ensures each a[..., i] is F-contiguous for getrs_batch
239+ lu = dpnp .moveaxis (lu , 0 , - 1 )
240+
241+ b_orig_shape = b .shape
242+ if b .ndim > 2 :
243+ b = dpnp .reshape (b , (- 1 , n , nrhs ))
244+
245+ # Move batch axis to the end (n, nrhs, batch) in Fortran order:
246+ # required by getrs_batch
247+ # and ensures each b[..., i] is F-contiguous for getrs_batch
248+ b = dpnp .moveaxis (b , 0 , - 1 )
249+
250+ lu_usm_arr = dpnp .get_usm_ndarray (lu )
251+ b_usm_arr = dpnp .get_usm_ndarray (b )
252+
253+ # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
254+ # convert to 1-based for oneMKL getrs_batch
255+ piv_h = piv + 1
256+
257+ _manager = dpu .SequentialOrderManager [exec_q ]
258+ dep_evs = _manager .submitted_events
259+
260+ # oneMKL LAPACK getrs overwrites `lu`.
261+ lu_h = dpnp .empty_like (lu , order = "F" , dtype = res_type , usm_type = res_usm_type )
262+
263+ # use DPCTL tensor function to fill the сopy of the input array
264+ # from the input array
265+ ht_ev , lu_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
266+ src = lu_usm_arr ,
267+ dst = lu_h .get_array (),
268+ sycl_queue = lu .sycl_queue ,
269+ depends = dep_evs ,
270+ )
271+ _manager .add_event_pair (ht_ev , lu_copy_ev )
272+
273+ b_h = dpnp .empty_like (b , order = "F" , dtype = res_type , usm_type = res_usm_type )
274+ ht_ev , b_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
275+ src = b_usm_arr ,
276+ dst = b_h .get_array (),
277+ sycl_queue = b .sycl_queue ,
278+ depends = dep_evs ,
279+ )
280+ _manager .add_event_pair (ht_ev , b_copy_ev )
281+ dep_evs = [lu_copy_ev , b_copy_ev ]
282+
283+ lu_stride = n * n
284+ piv_stride = n
285+ b_stride = n * nrhs
286+
287+ trans_mkl = _map_trans_to_mkl (trans )
288+
289+ # Call the LAPACK extension function _getrs_batch
290+ # to solve the system of linear equations with an LU-factored
291+ # coefficient square matrix, with multiple right-hand sides.
292+ ht_ev , getrs_batch_ev = li ._getrs_batch (
293+ exec_q ,
294+ lu_h .get_array (),
295+ piv_h .get_array (),
296+ b_h .get_array (),
297+ trans_mkl ,
298+ n ,
299+ nrhs ,
300+ lu_stride ,
301+ piv_stride ,
302+ b_stride ,
303+ batch_size ,
304+ depends = dep_evs ,
305+ )
306+ _manager .add_event_pair (ht_ev , getrs_batch_ev )
307+
308+ # Restore original shape: move batch axis back and reshape
309+ b_h = dpnp .moveaxis (b_h , - 1 , 0 ).reshape (b_orig_shape )
310+
311+ return b_h
312+
313+
183314def _is_copy_required (a , res_type ):
184315 """
185316 Determine if `a` needs to be copied before LU decomposition.
@@ -197,6 +328,20 @@ def _is_copy_required(a, res_type):
197328 return False
198329
199330
331+ def _map_trans_to_mkl (trans ):
332+ """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
333+ if not isinstance (trans , int ):
334+ raise TypeError ("`trans` must be an integer" )
335+
336+ if trans == 0 :
337+ return li .Transpose .N
338+ if trans == 1 :
339+ return li .Transpose .T
340+ if trans == 2 :
341+ return li .Transpose .C
342+ raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
343+
344+
200345def dpnp_lu_factor (a , overwrite_a = False , check_finite = True ):
201346 """
202347 dpnp_lu_factor(a, overwrite_a=False, check_finite=True)
@@ -307,18 +452,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
307452
308453 res_type = _common_type (lu , b )
309454
310- # TODO: add broadcasting
311- if lu .shape [0 ] != b .shape [0 ]:
312- raise ValueError (
313- f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
314- )
315-
316455 if b .size == 0 :
317456 return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
318457
319- if lu .ndim > 2 :
320- raise NotImplementedError ("Batched matrices are not supported" )
321-
322458 if check_finite :
323459 if not dpnp .isfinite (lu ).all ():
324460 raise ValueError (
@@ -331,6 +467,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
331467 "Right-hand side array must not contain infs or NaNs"
332468 )
333469
470+ if lu .ndim > 2 :
471+ # SciPy always copies each 2D slice,
472+ # so `overwrite_b` is ignored here
473+ return _batched_lu_solve (lu , piv , b , trans = trans , res_type = res_type )
474+
475+ if lu .shape [0 ] != b .shape [0 ]:
476+ raise ValueError (
477+ f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
478+ )
479+
334480 lu_usm_arr = dpnp .get_usm_ndarray (lu )
335481 b_usm_arr = dpnp .get_usm_ndarray (b )
336482
@@ -377,18 +523,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
377523 b_h = b
378524 dep_evs = [lu_copy_ev ]
379525
380- if not isinstance (trans , int ):
381- raise TypeError ("`trans` must be an integer" )
382-
383- # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
384- if trans == 0 :
385- trans_mkl = li .Transpose .N
386- elif trans == 1 :
387- trans_mkl = li .Transpose .T
388- elif trans == 2 :
389- trans_mkl = li .Transpose .C
390- else :
391- raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
526+ trans_mkl = _map_trans_to_mkl (trans )
392527
393528 # Call the LAPACK extension function _getrs
394529 # to solve the system of linear equations with an LU-factored
0 commit comments