Skip to content

Commit 9345b7b

Browse files
Add trans code handling
1 parent 1e09cb7 commit 9345b7b

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2611,6 +2611,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
26112611
# MKL lapack uses 1-origin while SciPy uses 0-origin
26122612
piv_h += 1
26132613

2614+
if not isinstance(trans, int):
2615+
raise TypeError("`trans` must be an integer")
2616+
2617+
# Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
2618+
if trans == 0:
2619+
trans_mkl = li.Transpose.N
2620+
elif trans == 1:
2621+
trans_mkl = li.Transpose.T
2622+
elif trans == 2:
2623+
trans_mkl = li.Transpose.C
2624+
else:
2625+
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
2626+
26142627
# Call the LAPACK extension function _getrs
26152628
# to solve the system of linear equations with an LU-factored
26162629
# coefficient square matrix, with multiple right-hand sides.
@@ -2619,7 +2632,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
26192632
lu_h.get_array(),
26202633
piv_h.get_array(),
26212634
b_h.get_array(),
2622-
trans,
2635+
trans_mkl,
26232636
depends=dep_evs,
26242637
)
26252638
_manager.add_event_pair(ht_ev, getrs_ev)

0 commit comments

Comments
 (0)