Skip to content

Commit 14feed3

Browse files
Implement _batched_lu_solve
1 parent 4270823 commit 14feed3

File tree

1 file changed

+159
-21
lines changed

1 file changed

+159
-21
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 159 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,37 @@ class SVDResult(NamedTuple):
107107
}
108108

109109

110+
def _align_lu_solve_broadcast(lu, b):
111+
"""Align LU and RHS batch dimensions with SciPy-like rules."""
112+
lu_shape = lu.shape
113+
b_shape = b.shape
114+
115+
if b.ndim < 2:
116+
if lu_shape[-2] != b_shape[0]:
117+
raise ValueError(
118+
f"Shapes of lu {lu_shape} and b {b_shape} are incompatible"
119+
)
120+
b = dpnp.broadcast_to(b, lu_shape[:-1])
121+
return lu, b
122+
123+
if lu_shape[-2] != b_shape[-2]:
124+
raise ValueError(
125+
f"Shapes of lu {lu_shape} and b {b_shape} are incompatible"
126+
)
127+
128+
# Use dpnp.broadcast_shapes() to align the resulting batch shapes
129+
batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2])
130+
lu_bshape = batch + lu_shape[-2:]
131+
b_bshape = batch + b_shape[-2:]
132+
133+
if lu_shape != lu_bshape:
134+
lu = dpnp.broadcast_to(lu, lu_bshape)
135+
if b_shape != b_bshape:
136+
b = dpnp.broadcast_to(b, b_bshape)
137+
138+
return lu, b
139+
140+
110141
def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
111142
"""
112143
_batched_eigh(a, UPLO, eigen_mode, w_type, v_type)
@@ -486,6 +517,109 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
486517
return (a_h, ipiv_h)
487518

488519

520+
def _batched_lu_solve(lu, piv, b, res_type, trans=0):
521+
"""Solve a batched equation system (SciPy-compatible behavior)."""
522+
res_usm_type, exec_q = get_usm_allocations([lu, piv, b])
523+
524+
if b.size == 0:
525+
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
526+
527+
b_ndim = b.ndim
528+
529+
lu, b = _align_lu_solve_broadcast(lu, b)
530+
531+
n = lu.shape[-1]
532+
nrhs = b.shape[-1] if b_ndim > 1 else 1
533+
534+
# get 3d input arrays by reshape
535+
if lu.ndim > 3:
536+
lu = dpnp.reshape(lu, (-1, n, n))
537+
# get 2d pivot arrays by reshape
538+
if piv.ndim > 2:
539+
piv = dpnp.reshape(piv, (-1, n))
540+
batch_size = lu.shape[0]
541+
542+
# Move batch axis to the end (n, n, batch) in Fortran order:
543+
# required by getrs_batch
544+
# and ensures each a[..., i] is F-contiguous for getrs_batch
545+
lu = dpnp.moveaxis(lu, 0, -1)
546+
547+
b_orig_shape = b.shape
548+
if b.ndim > 2:
549+
b = dpnp.reshape(b, (-1, n, nrhs))
550+
551+
# Move batch axis to the end (n, nrhs, batch) in Fortran order:
552+
# required by getrs_batch
553+
# and ensures each b[..., i] is F-contiguous for getrs_batch
554+
b = dpnp.moveaxis(b, 0, -1)
555+
556+
lu_usm_arr = dpnp.get_usm_ndarray(lu)
557+
b_usm_arr = dpnp.get_usm_ndarray(b)
558+
559+
# dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
560+
# convert to 1-based for oneMKL getrs_batch
561+
piv_h = piv + 1
562+
563+
_manager = dpu.SequentialOrderManager[exec_q]
564+
dep_evs = _manager.submitted_events
565+
566+
# oneMKL LAPACK getrs overwrites `lu`.
567+
lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)
568+
569+
# use DPCTL tensor function to fill the сopy of the input array
570+
# from the input array
571+
ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
572+
src=lu_usm_arr,
573+
dst=lu_h.get_array(),
574+
sycl_queue=lu.sycl_queue,
575+
depends=dep_evs,
576+
)
577+
_manager.add_event_pair(ht_ev, lu_copy_ev)
578+
579+
b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type)
580+
ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
581+
src=b_usm_arr,
582+
dst=b_h.get_array(),
583+
sycl_queue=b.sycl_queue,
584+
depends=dep_evs,
585+
)
586+
_manager.add_event_pair(ht_ev, b_copy_ev)
587+
dep_evs = [lu_copy_ev, b_copy_ev]
588+
589+
lu_stride = lu_h.strides[-1]
590+
piv_stride = piv.strides[0]
591+
b_stride = b_h.strides[-1]
592+
593+
if not isinstance(trans, int):
594+
raise TypeError("`trans` must be an integer")
595+
596+
trans_mkl = _map_trans_to_mkl(trans)
597+
598+
# Call the LAPACK extension function _getrs_batch
599+
# to solve the system of linear equations with an LU-factored
600+
# coefficient square matrix, with multiple right-hand sides.
601+
ht_ev, getrs_batch_ev = li._getrs_batch(
602+
exec_q,
603+
lu_h.get_array(),
604+
piv_h.get_array(),
605+
b_h.get_array(),
606+
trans_mkl,
607+
n,
608+
nrhs,
609+
lu_stride,
610+
piv_stride,
611+
b_stride,
612+
batch_size,
613+
depends=dep_evs,
614+
)
615+
_manager.add_event_pair(ht_ev, getrs_batch_ev)
616+
617+
# Restore original shape: move batch axis back and reshape
618+
b_h = dpnp.moveaxis(b_h, -1, 0).reshape(b_orig_shape)
619+
620+
return b_h
621+
622+
489623
def _batched_solve(a, b, exec_q, res_usm_type, res_type):
490624
"""
491625
_batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -1099,6 +1233,20 @@ def _is_empty_2d(arr):
10991233
return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0
11001234

11011235

1236+
def _map_trans_to_mkl(trans):
1237+
"""Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
1238+
if not isinstance(trans, int):
1239+
raise TypeError("`trans` must be an integer")
1240+
1241+
if trans == 0:
1242+
return li.Transpose.N
1243+
if trans == 1:
1244+
return li.Transpose.T
1245+
if trans == 2:
1246+
return li.Transpose.C
1247+
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
1248+
1249+
11021250
def _lu_factor(a, res_type):
11031251
"""
11041252
Compute pivoted LU decomposition.
@@ -2493,18 +2641,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
24932641

24942642
res_type = _common_type(lu, b)
24952643

2496-
# TODO: add broadcasting
2497-
if lu.shape[0] != b.shape[0]:
2498-
raise ValueError(
2499-
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
2500-
)
2501-
25022644
if b.size == 0:
25032645
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
25042646

2505-
if lu.ndim > 2:
2506-
raise NotImplementedError("Batched matrices are not supported")
2507-
25082647
if check_finite:
25092648
if not dpnp.isfinite(lu).all():
25102649
raise ValueError(
@@ -2517,6 +2656,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25172656
"Right-hand side array must not contain infs or NaNs"
25182657
)
25192658

2659+
if lu.ndim > 2:
2660+
# SciPy always copies each 2D slice,
2661+
# so `overwrite_b` is ignored here
2662+
return _batched_lu_solve(lu, piv, b, trans=trans, res_type=res_type)
2663+
2664+
if lu.shape[0] != b.shape[0]:
2665+
raise ValueError(
2666+
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
2667+
)
2668+
25202669
lu_usm_arr = dpnp.get_usm_ndarray(lu)
25212670
b_usm_arr = dpnp.get_usm_ndarray(b)
25222671

@@ -2563,18 +2712,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25632712
b_h = b
25642713
dep_evs = [lu_copy_ev]
25652714

2566-
if not isinstance(trans, int):
2567-
raise TypeError("`trans` must be an integer")
2568-
2569-
# Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
2570-
if trans == 0:
2571-
trans_mkl = li.Transpose.N
2572-
elif trans == 1:
2573-
trans_mkl = li.Transpose.T
2574-
elif trans == 2:
2575-
trans_mkl = li.Transpose.C
2576-
else:
2577-
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
2715+
trans_mkl = _map_trans_to_mkl(trans)
25782716

25792717
# Call the LAPACK extension function _getrs
25802718
# to solve the system of linear equations with an LU-factored

0 commit comments

Comments
 (0)