Skip to content

Commit a13cd88

Browse files
Compute strides explicitly
1 parent 5e21a02 commit a13cd88

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -586,12 +586,9 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0):
586586
_manager.add_event_pair(ht_ev, b_copy_ev)
587587
dep_evs = [lu_copy_ev, b_copy_ev]
588588

589-
lu_stride = lu_h.strides[-1]
590-
piv_stride = piv.strides[0]
591-
b_stride = b_h.strides[-1]
592-
593-
if not isinstance(trans, int):
594-
raise TypeError("`trans` must be an integer")
589+
lu_stride = n * n
590+
piv_stride = n
591+
b_stride = n * nrhs
595592

596593
trans_mkl = _map_trans_to_mkl(trans)
597594

0 commit comments

Comments
 (0)