55
55
]
56
56
57
57
58
+ def _align_lu_solve_broadcast (lu , b ):
59
+ """Align LU and RHS batch dimensions with SciPy-like rules."""
60
+ lu_shape = lu .shape
61
+ b_shape = b .shape
62
+
63
+ if b .ndim < 2 :
64
+ if lu_shape [- 2 ] != b_shape [0 ]:
65
+ raise ValueError (
66
+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
67
+ )
68
+ b = dpnp .broadcast_to (b , lu_shape [:- 1 ])
69
+ return lu , b
70
+
71
+ if lu_shape [- 2 ] != b_shape [- 2 ]:
72
+ raise ValueError (
73
+ f"Shapes of lu { lu_shape } and b { b_shape } are incompatible"
74
+ )
75
+
76
+ # Use dpnp.broadcast_shapes() to align the resulting batch shapes
77
+ batch = dpnp .broadcast_shapes (lu_shape [:- 2 ], b_shape [:- 2 ])
78
+ lu_bshape = batch + lu_shape [- 2 :]
79
+ b_bshape = batch + b_shape [- 2 :]
80
+
81
+ if lu_shape != lu_bshape :
82
+ lu = dpnp .broadcast_to (lu , lu_bshape )
83
+ if b_shape != b_bshape :
84
+ b = dpnp .broadcast_to (b , b_bshape )
85
+
86
+ return lu , b
87
+
88
+
58
89
def _batched_lu_factor_scipy (a , res_type ): # pylint: disable=too-many-locals
59
90
"""SciPy-compatible LU factorization for batched inputs."""
60
91
@@ -180,6 +211,106 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
180
211
return (a_h , ipiv_h )
181
212
182
213
214
def _batched_lu_solve(lu, piv, b, res_type, trans=0):
    """
    Solve a batched equation system (SciPy-compatible behavior).

    Parameters
    ----------
    lu : array
        Batched LU factors; the last two axes hold one ``(n, n)`` matrix.
    piv : array
        0-based pivot indices matching `lu` (one row of `n` indices per
        matrix).
    b : array
        Right-hand side(s); batch dimensions are broadcast against `lu` by
        ``_align_lu_solve_broadcast``.
    res_type : dtype
        Data type of the working copies and of the result.
    trans : int, optional
        SciPy-style transpose code, translated by ``_map_trans_to_mkl``.

    Returns
    -------
    array
        The solution, shaped like the broadcast right-hand side. For an
        empty `b` an empty array shaped like `b` is returned.

    """
    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])

    if b.size == 0:
        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

    b_ndim = b.ndim

    lu, b = _align_lu_solve_broadcast(lu, b)

    # `lu` may have been broadcast to a larger batch shape; the pivots have
    # to follow it, otherwise `batch_size` below would exceed the number of
    # pivot rows handed to getrs_batch.
    piv_bshape = lu.shape[:-1]
    if piv.shape != piv_bshape:
        piv = dpnp.broadcast_to(piv, piv_bshape)

    n = lu.shape[-1]
    nrhs = b.shape[-1] if b_ndim > 1 else 1

    # get 3d input arrays by reshape
    if lu.ndim > 3:
        lu = dpnp.reshape(lu, (-1, n, n))
    # get 2d pivot arrays by reshape
    if piv.ndim > 2:
        piv = dpnp.reshape(piv, (-1, n))
    batch_size = lu.shape[0]

    # Move batch axis to the end (n, n, batch) in Fortran order:
    # required by getrs_batch
    # and ensures each a[..., i] is F-contiguous for getrs_batch
    lu = dpnp.moveaxis(lu, 0, -1)

    b_orig_shape = b.shape
    if b.ndim > 2:
        b = dpnp.reshape(b, (-1, n, nrhs))

    # Move batch axis to the end (n, nrhs, batch) in Fortran order:
    # required by getrs_batch
    # and ensures each b[..., i] is F-contiguous for getrs_batch
    b = dpnp.moveaxis(b, 0, -1)

    lu_usm_arr = dpnp.get_usm_ndarray(lu)
    b_usm_arr = dpnp.get_usm_ndarray(b)

    # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
    # convert to 1-based for oneMKL getrs_batch
    # (computed before the submitted events are read just below)
    piv_h = piv + 1

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # oneMKL LAPACK getrs overwrites `lu`.
    lu_h = dpnp.empty_like(
        lu, order="F", dtype=res_type, usm_type=res_usm_type
    )

    # use DPCTL tensor function to fill the copy of the input array
    # from the input array
    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=lu_usm_arr,
        dst=lu_h.get_array(),
        sycl_queue=lu.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, lu_copy_ev)

    # The right-hand side is copied as well; the solve is run on `b_h`,
    # which is what gets returned.
    b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type)
    ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=b_usm_arr,
        dst=b_h.get_array(),
        sycl_queue=b.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, b_copy_ev)
    # The solve must wait for both copies
    dep_evs = [lu_copy_ev, b_copy_ev]

    # Element distance between consecutive batch entries
    lu_stride = n * n
    piv_stride = n
    b_stride = n * nrhs

    trans_mkl = _map_trans_to_mkl(trans)

    # Call the LAPACK extension function _getrs_batch
    # to solve the system of linear equations with an LU-factored
    # coefficient square matrix, with multiple right-hand sides.
    ht_ev, getrs_batch_ev = li._getrs_batch(
        exec_q,
        lu_h.get_array(),
        piv_h.get_array(),
        b_h.get_array(),
        trans_mkl,
        n,
        nrhs,
        lu_stride,
        piv_stride,
        b_stride,
        batch_size,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, getrs_batch_ev)

    # Restore original shape: move batch axis back and reshape
    b_h = dpnp.moveaxis(b_h, -1, 0).reshape(b_orig_shape)

    return b_h
312
+
313
+
183
314
def _is_copy_required (a , res_type ):
184
315
"""
185
316
Determine if `a` needs to be copied before LU decomposition.
@@ -197,6 +328,20 @@ def _is_copy_required(a, res_type):
197
328
return False
198
329
199
330
331
+ def _map_trans_to_mkl (trans ):
332
+ """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
333
+ if not isinstance (trans , int ):
334
+ raise TypeError ("`trans` must be an integer" )
335
+
336
+ if trans == 0 :
337
+ return li .Transpose .N
338
+ if trans == 1 :
339
+ return li .Transpose .T
340
+ if trans == 2 :
341
+ return li .Transpose .C
342
+ raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
343
+
344
+
200
345
def dpnp_lu_factor (a , overwrite_a = False , check_finite = True ):
201
346
"""
202
347
dpnp_lu_factor(a, overwrite_a=False, check_finite=True)
@@ -307,18 +452,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
307
452
308
453
res_type = _common_type (lu , b )
309
454
310
- # TODO: add broadcasting
311
- if lu .shape [0 ] != b .shape [0 ]:
312
- raise ValueError (
313
- f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
314
- )
315
-
316
455
if b .size == 0 :
317
456
return dpnp .empty_like (b , dtype = res_type , usm_type = res_usm_type )
318
457
319
- if lu .ndim > 2 :
320
- raise NotImplementedError ("Batched matrices are not supported" )
321
-
322
458
if check_finite :
323
459
if not dpnp .isfinite (lu ).all ():
324
460
raise ValueError (
@@ -331,6 +467,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
331
467
"Right-hand side array must not contain infs or NaNs"
332
468
)
333
469
470
+ if lu .ndim > 2 :
471
+ # SciPy always copies each 2D slice,
472
+ # so `overwrite_b` is ignored here
473
+ return _batched_lu_solve (lu , piv , b , trans = trans , res_type = res_type )
474
+
475
+ if lu .shape [0 ] != b .shape [0 ]:
476
+ raise ValueError (
477
+ f"Shapes of lu { lu .shape } and b { b .shape } are incompatible"
478
+ )
479
+
334
480
lu_usm_arr = dpnp .get_usm_ndarray (lu )
335
481
b_usm_arr = dpnp .get_usm_ndarray (b )
336
482
@@ -377,18 +523,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
377
523
b_h = b
378
524
dep_evs = [lu_copy_ev ]
379
525
380
- if not isinstance (trans , int ):
381
- raise TypeError ("`trans` must be an integer" )
382
-
383
- # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
384
- if trans == 0 :
385
- trans_mkl = li .Transpose .N
386
- elif trans == 1 :
387
- trans_mkl = li .Transpose .T
388
- elif trans == 2 :
389
- trans_mkl = li .Transpose .C
390
- else :
391
- raise ValueError ("`trans` must be 0 (N), 1 (T), or 2 (C)" )
526
+ trans_mkl = _map_trans_to_mkl (trans )
392
527
393
528
# Call the LAPACK extension function _getrs
394
529
# to solve the system of linear equations with an LU-factored
0 commit comments