Skip to content

Commit 183293b

Browse files
Apply remarks
1 parent ee24137 commit 183293b

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

dpnp/scipy/linalg/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0):
218218
"""Solve a batched equation system (SciPy-compatible behavior)."""
219219
res_usm_type, exec_q = get_usm_allocations([lu, piv, b])
220220

221-
b_ndim = b.ndim
221+
b_ndim_orig = b.ndim
222222

223223
lu, b = _align_lu_solve_broadcast(lu, b)
224224

225225
n = lu.shape[-1]
226-
nrhs = b.shape[-1] if b_ndim > 1 else 1
226+
nrhs = b.shape[-1] if b_ndim_orig > 1 else 1
227227

228228
# get 3d input arrays by reshape
229229
if lu.ndim > 3:
@@ -235,11 +235,11 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0):
235235

236236
# Move batch axis to the end (n, n, batch) in Fortran order:
237237
# required by getrs_batch
238-
# and ensures each a[..., i] is F-contiguous for getrs_batch
238+
# and ensures each lu[..., i] is F-contiguous for getrs_batch
239239
lu = dpnp.moveaxis(lu, 0, -1)
240240

241241
b_orig_shape = b.shape
242-
if b.ndim > 2:
242+
if b.ndim > 3:
243243
b = dpnp.reshape(b, (-1, n, nrhs))
244244

245245
# Move batch axis to the end (n, nrhs, batch) in Fortran order:

dpnp/tests/test_linalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2339,8 +2339,7 @@ def test_empty_shapes(self, a_shape, b_shape):
23392339
n = a_shape[0]
23402340

23412341
if n > 0:
2342-
for i in range(n):
2343-
a_dp[i, i] = a_dp.dtype.type(1.0)
2342+
dpnp.fill_diagonal(a_dp, a_dp.dtype.type(1.0))
23442343
b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
23452344

23462345
lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False)

0 commit comments

Comments
 (0)