Skip to content

Commit 959f5f8

Browse files
Move changes to new location(scipy folder)
1 parent 7464a25 commit 959f5f8

File tree

4 files changed

+177
-83
lines changed

4 files changed

+177
-83
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -106,37 +106,6 @@ class SVDResult(NamedTuple):
106106
}
107107

108108

109-
def _align_lu_solve_broadcast(lu, b):
110-
"""Align LU and RHS batch dimensions with SciPy-like rules."""
111-
lu_shape = lu.shape
112-
b_shape = b.shape
113-
114-
if b.ndim < 2:
115-
if lu_shape[-2] != b_shape[0]:
116-
raise ValueError(
117-
f"Shapes of lu {lu_shape} and b {b_shape} are incompatible"
118-
)
119-
b = dpnp.broadcast_to(b, lu_shape[:-1])
120-
return lu, b
121-
122-
if lu_shape[-2] != b_shape[-2]:
123-
raise ValueError(
124-
f"Shapes of lu {lu_shape} and b {b_shape} are incompatible"
125-
)
126-
127-
# Use dpnp.broadcast_shapes() to align the resulting batch shapes
128-
batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2])
129-
lu_bshape = batch + lu_shape[-2:]
130-
b_bshape = batch + b_shape[-2:]
131-
132-
if lu_shape != lu_bshape:
133-
lu = dpnp.broadcast_to(lu, lu_bshape)
134-
if b_shape != b_bshape:
135-
b = dpnp.broadcast_to(b, b_bshape)
136-
137-
return lu, b
138-
139-
140109
def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
141110
"""
142111
_batched_eigh(a, UPLO, eigen_mode, w_type, v_type)
@@ -987,20 +956,6 @@ def _is_empty_2d(arr):
987956
return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0
988957

989958

990-
def _map_trans_to_mkl(trans):
991-
"""Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
992-
if not isinstance(trans, int):
993-
raise TypeError("`trans` must be an integer")
994-
995-
if trans == 0:
996-
return li.Transpose.N
997-
if trans == 1:
998-
return li.Transpose.T
999-
if trans == 2:
1000-
return li.Transpose.C
1001-
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
1002-
1003-
1004959
def _lu_factor(a, res_type):
1005960
"""
1006961
Compute pivoted LU decomposition.

dpnp/scipy/linalg/_decomp_lu.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@
3939

4040

4141
import dpnp
42-
from dpnp.linalg.dpnp_utils_linalg import assert_stacked_2d
42+
from dpnp.linalg.dpnp_utils_linalg import (
43+
assert_stacked_2d,
44+
assert_stacked_square,
45+
)
4346

4447
from ._utils import (
4548
dpnp_lu_factor,
@@ -184,6 +187,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
184187
(lu, piv) = lu_and_piv
185188
dpnp.check_supported_arrays_type(lu, piv, b)
186189
assert_stacked_2d(lu)
190+
assert_stacked_square(lu)
187191

188192
return dpnp_lu_solve(
189193
lu,

dpnp/scipy/linalg/_utils.py

Lines changed: 156 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,37 @@
5555
]
5656

5757

58+
def _align_lu_solve_broadcast(lu, b):
59+
"""Align LU and RHS batch dimensions with SciPy-like rules."""
60+
lu_shape = lu.shape
61+
b_shape = b.shape
62+
63+
if b.ndim < 2:
64+
if lu_shape[-2] != b_shape[0]:
65+
raise ValueError(
66+
f"Shapes of lu {lu_shape} and b {b_shape} are incompatible"
67+
)
68+
b = dpnp.broadcast_to(b, lu_shape[:-1])
69+
return lu, b
70+
71+
if lu_shape[-2] != b_shape[-2]:
72+
raise ValueError(
73+
f"Shapes of lu {lu_shape} and b {b_shape} are incompatible"
74+
)
75+
76+
# Use dpnp.broadcast_shapes() to align the resulting batch shapes
77+
batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2])
78+
lu_bshape = batch + lu_shape[-2:]
79+
b_bshape = batch + b_shape[-2:]
80+
81+
if lu_shape != lu_bshape:
82+
lu = dpnp.broadcast_to(lu, lu_bshape)
83+
if b_shape != b_bshape:
84+
b = dpnp.broadcast_to(b, b_bshape)
85+
86+
return lu, b
87+
88+
5889
def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
5990
"""SciPy-compatible LU factorization for batched inputs."""
6091

@@ -180,6 +211,106 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
180211
return (a_h, ipiv_h)
181212

182213

214+
def _batched_lu_solve(lu, piv, b, res_type, trans=0):
215+
"""Solve a batched equation system (SciPy-compatible behavior)."""
216+
res_usm_type, exec_q = get_usm_allocations([lu, piv, b])
217+
218+
if b.size == 0:
219+
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
220+
221+
b_ndim = b.ndim
222+
223+
lu, b = _align_lu_solve_broadcast(lu, b)
224+
225+
n = lu.shape[-1]
226+
nrhs = b.shape[-1] if b_ndim > 1 else 1
227+
228+
# get 3d input arrays by reshape
229+
if lu.ndim > 3:
230+
lu = dpnp.reshape(lu, (-1, n, n))
231+
# get 2d pivot arrays by reshape
232+
if piv.ndim > 2:
233+
piv = dpnp.reshape(piv, (-1, n))
234+
batch_size = lu.shape[0]
235+
236+
# Move batch axis to the end (n, n, batch) in Fortran order:
237+
# required by getrs_batch
238+
# and ensures each a[..., i] is F-contiguous for getrs_batch
239+
lu = dpnp.moveaxis(lu, 0, -1)
240+
241+
b_orig_shape = b.shape
242+
if b.ndim > 2:
243+
b = dpnp.reshape(b, (-1, n, nrhs))
244+
245+
# Move batch axis to the end (n, nrhs, batch) in Fortran order:
246+
# required by getrs_batch
247+
# and ensures each b[..., i] is F-contiguous for getrs_batch
248+
b = dpnp.moveaxis(b, 0, -1)
249+
250+
lu_usm_arr = dpnp.get_usm_ndarray(lu)
251+
b_usm_arr = dpnp.get_usm_ndarray(b)
252+
253+
# dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
254+
# convert to 1-based for oneMKL getrs_batch
255+
piv_h = piv + 1
256+
257+
_manager = dpu.SequentialOrderManager[exec_q]
258+
dep_evs = _manager.submitted_events
259+
260+
# oneMKL LAPACK getrs overwrites `lu`.
261+
lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)
262+
263+
# use DPCTL tensor function to fill the сopy of the input array
264+
# from the input array
265+
ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
266+
src=lu_usm_arr,
267+
dst=lu_h.get_array(),
268+
sycl_queue=lu.sycl_queue,
269+
depends=dep_evs,
270+
)
271+
_manager.add_event_pair(ht_ev, lu_copy_ev)
272+
273+
b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type)
274+
ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
275+
src=b_usm_arr,
276+
dst=b_h.get_array(),
277+
sycl_queue=b.sycl_queue,
278+
depends=dep_evs,
279+
)
280+
_manager.add_event_pair(ht_ev, b_copy_ev)
281+
dep_evs = [lu_copy_ev, b_copy_ev]
282+
283+
lu_stride = n * n
284+
piv_stride = n
285+
b_stride = n * nrhs
286+
287+
trans_mkl = _map_trans_to_mkl(trans)
288+
289+
# Call the LAPACK extension function _getrs_batch
290+
# to solve the system of linear equations with an LU-factored
291+
# coefficient square matrix, with multiple right-hand sides.
292+
ht_ev, getrs_batch_ev = li._getrs_batch(
293+
exec_q,
294+
lu_h.get_array(),
295+
piv_h.get_array(),
296+
b_h.get_array(),
297+
trans_mkl,
298+
n,
299+
nrhs,
300+
lu_stride,
301+
piv_stride,
302+
b_stride,
303+
batch_size,
304+
depends=dep_evs,
305+
)
306+
_manager.add_event_pair(ht_ev, getrs_batch_ev)
307+
308+
# Restore original shape: move batch axis back and reshape
309+
b_h = dpnp.moveaxis(b_h, -1, 0).reshape(b_orig_shape)
310+
311+
return b_h
312+
313+
183314
def _is_copy_required(a, res_type):
184315
"""
185316
Determine if `a` needs to be copied before LU decomposition.
@@ -197,6 +328,20 @@ def _is_copy_required(a, res_type):
197328
return False
198329

199330

331+
def _map_trans_to_mkl(trans):
332+
"""Map SciPy-style trans code (0,1,2) to oneMKL transpose enum."""
333+
if not isinstance(trans, int):
334+
raise TypeError("`trans` must be an integer")
335+
336+
if trans == 0:
337+
return li.Transpose.N
338+
if trans == 1:
339+
return li.Transpose.T
340+
if trans == 2:
341+
return li.Transpose.C
342+
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
343+
344+
200345
def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
201346
"""
202347
dpnp_lu_factor(a, overwrite_a=False, check_finite=True)
@@ -307,18 +452,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
307452

308453
res_type = _common_type(lu, b)
309454

310-
# TODO: add broadcasting
311-
if lu.shape[0] != b.shape[0]:
312-
raise ValueError(
313-
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
314-
)
315-
316455
if b.size == 0:
317456
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
318457

319-
if lu.ndim > 2:
320-
raise NotImplementedError("Batched matrices are not supported")
321-
322458
if check_finite:
323459
if not dpnp.isfinite(lu).all():
324460
raise ValueError(
@@ -331,6 +467,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
331467
"Right-hand side array must not contain infs or NaNs"
332468
)
333469

470+
if lu.ndim > 2:
471+
# SciPy always copies each 2D slice,
472+
# so `overwrite_b` is ignored here
473+
return _batched_lu_solve(lu, piv, b, trans=trans, res_type=res_type)
474+
475+
if lu.shape[0] != b.shape[0]:
476+
raise ValueError(
477+
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
478+
)
479+
334480
lu_usm_arr = dpnp.get_usm_ndarray(lu)
335481
b_usm_arr = dpnp.get_usm_ndarray(b)
336482

@@ -377,18 +523,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
377523
b_h = b
378524
dep_evs = [lu_copy_ev]
379525

380-
if not isinstance(trans, int):
381-
raise TypeError("`trans` must be an integer")
382-
383-
# Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
384-
if trans == 0:
385-
trans_mkl = li.Transpose.N
386-
elif trans == 1:
387-
trans_mkl = li.Transpose.T
388-
elif trans == 2:
389-
trans_mkl = li.Transpose.C
390-
else:
391-
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")
526+
trans_mkl = _map_trans_to_mkl(trans)
392527

393528
# Call the LAPACK extension function _getrs
394529
# to solve the system of linear equations with an LU-factored

0 commit comments

Comments
 (0)