Skip to content

Commit b2d2b3a

Browse files
Implementation of dpnp.linalg.lu_solve for 2D inputs
1 parent ce878e6 commit b2d2b3a

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
dpnp_inv,
5858
dpnp_lstsq,
5959
dpnp_lu_factor,
60+
dpnp_lu_solve,
6061
dpnp_matrix_power,
6162
dpnp_matrix_rank,
6263
dpnp_multi_dot,
@@ -81,6 +82,7 @@
8182
"inv",
8283
"lstsq",
8384
"lu_factor",
85+
"lu_solve",
8486
"matmul",
8587
"matrix_norm",
8688
"matrix_power",
@@ -966,6 +968,75 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
966968
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
967969

968970

971+
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    Solve the linear system ``a x = b`` from an LU factorization of `a`.

    For full documentation refer to :obj:`scipy.linalg.lu_solve`.

    Parameters
    ----------
    (lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays}
        The LU factorization of the matrix `a` and the matching pivot
        indices, passed together as the two-element `lu_and_piv` argument.
        Batched (more than 2-D) factorizations are not supported yet.
    b : {(M,), (M, K)} {dpnp.ndarray, usm_ndarray}
        Right-hand side of the system. Its first dimension must match
        the first dimension of `lu`.
    trans : {0, 1, 2}, optional
        Selects which system is solved:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========

    overwrite_b : {None, bool}, optional
        Allow the data in `b` to be overwritten (may increase performance).

        Default: ``False``.
    check_finite : {None, bool}, optional
        Verify that the inputs contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities
        or NaNs.

        Default: ``True``.

    Returns
    -------
    x : {(M,), (M, K)} dpnp.ndarray
        Solution to the system.

    Warning
    -------
    This function synchronizes in order to validate array elements
    when ``check_finite=True``.

    Examples
    --------
    >>> import dpnp as np
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = np.linalg.lu_factor(A)
    >>> x = np.linalg.lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    array(True)

    """

    # Split the factorization pair; a wrong-length argument fails here.
    factors, pivots = lu_and_piv

    # Validate array types and that the factorization is at least 2-D
    # before handing off to the backend implementation.
    dpnp.check_supported_arrays_type(factors, pivots, b)
    assert_stacked_2d(factors)

    return dpnp_lu_solve(
        factors, pivots, b, trans, overwrite_b, check_finite
    )
1038+
1039+
9691040
def matmul(x1, x2, /):
9701041
"""
9711042
Computes the matrix product.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,6 +2514,118 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
25142514
return (a_h, ipiv_h)
25152515

25162516

2517+
def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)

    Solve an equation system (SciPy-compatible behavior).

    This function mimics the behavior of `scipy.linalg.lu_solve` including
    support for `trans`, `overwrite_b`, `check_finite`,
    and 0-based pivot indexing.

    Parameters
    ----------
    lu : {dpnp.ndarray, usm_ndarray}
        LU factorization of the coefficient matrix. Only 2-D input is
        supported; higher-dimensional input raises `NotImplementedError`.
    piv : {dpnp.ndarray, usm_ndarray}
        0-based pivot indices. They are copied and shifted to 1-based
        indexing before being handed to the LAPACK extension.
    b : {dpnp.ndarray, usm_ndarray}
        Right-hand side. Its first dimension must equal ``lu.shape[0]``.
    trans : int, optional
        Forwarded unchanged to ``li._getrs``.
    overwrite_b : bool, optional
        When true and `b` needs no conversion copy, the solution is written
        into `b` itself.
    check_finite : bool, optional
        When true, `lu` and `b` are checked for infs/NaNs first.

    Returns
    -------
    out : dpnp.ndarray
        The array holding the solution: a new F-ordered array of the common
        result dtype, or `b` itself when it was solved in place. For an empty
        `b`, an uninitialized array shaped like `b` is returned.

    Raises
    ------
    ValueError
        If the leading dimensions of `lu` and `b` differ, or if
        `check_finite` is set and `lu` or `b` contains infs or NaNs.
    NotImplementedError
        If `lu` has more than two dimensions.

    """

    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])

    res_type = _common_type(lu, b)

    # TODO: add broadcasting
    if lu.shape[0] != b.shape[0]:
        raise ValueError(
            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
        )

    # Nothing to solve for an empty right-hand side.
    if b.size == 0:
        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

    if lu.ndim > 2:
        raise NotImplementedError("Batched matrices are not supported")

    if check_finite:
        if not dpnp.isfinite(lu).all():
            raise ValueError(
                "array must not contain infs or NaNs.\n"
                "Note that when a singular matrix is given, unlike "
                "dpnp.linalg.lu_factor returns an array containing NaN."
            )
        if not dpnp.isfinite(b).all():
            raise ValueError("array must not contain infs or NaNs")

    lu_usm_arr = dpnp.get_usm_ndarray(lu)
    piv_usm_arr = dpnp.get_usm_ndarray(piv)
    b_usm_arr = dpnp.get_usm_ndarray(b)

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # Work on an F-ordered copy of `lu` in the result dtype; the caller's
    # factorization is left untouched.
    lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the input array
    # from the input array
    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=lu_usm_arr,
        dst=lu_h.get_array(),
        sycl_queue=lu.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, lu_copy_ev)

    # SciPy-compatible behavior
    # Copy is required if:
    # - overwrite_b is False (always copy),
    # - dtype mismatch,
    # - not F-contiguous,
    # - not writeable
    if not overwrite_b or _is_copy_required(b, res_type):
        b_h = dpnp.empty_like(
            b, order="F", dtype=res_type, usm_type=res_usm_type
        )
        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=b_usm_arr,
            dst=b_h.get_array(),
            sycl_queue=b.sycl_queue,
            depends=_manager.submitted_events,
        )
        _manager.add_event_pair(ht_ev, b_copy_ev)
    else:
        # input is suitable for in-place modification
        b_h = b

    # The pivots are shifted below, so operate on a private copy rather
    # than mutating the caller's `piv`.
    piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the pivot array
    # from the pivot array
    ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=piv_usm_arr,
        dst=piv_h.get_array(),
        sycl_queue=piv.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, piv_copy_ev)
    # MKL lapack uses 1-origin while SciPy uses 0-origin
    piv_h += 1

    # The solver must not start before the `lu` copy, the `b` copy (if any),
    # the pivot copy AND the pivot increment have all completed. Re-read the
    # manager's event list here instead of reusing an earlier one, which
    # would miss the pivot events.
    # NOTE(review): assumes the in-place add above registers its event with
    # this same order manager - confirm.
    getrs_deps = _manager.submitted_events

    # Call the LAPACK extension function _getrs
    # to solve the system of linear equations with an LU-factored
    # coefficient square matrix, with multiple right-hand sides.
    ht_ev, getrs_ev = li._getrs(
        exec_q,
        lu_h.get_array(),
        piv_h.get_array(),
        b_h.get_array(),
        trans,
        depends=getrs_deps,
    )
    _manager.add_event_pair(ht_ev, getrs_ev)

    return b_h
2627+
2628+
25172629
def dpnp_matrix_power(a, n):
25182630
"""
25192631
dpnp_matrix_power(a, n)

0 commit comments

Comments
 (0)