Skip to content

Commit 46deac3

Browse files
Update _batched_lu_factor_scipy by using single allocation with batch-axis views
1 parent af74642 commit 46deac3

File tree

1 file changed

+56
-97
lines changed

1 file changed

+56
-97
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 56 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,9 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
402402
"""SciPy-compatible LU factorization for batched inputs."""
403403

404404
# TODO: Find out at which array sizes the best performance is obtained
405-
# getrf_batch implementation shows slow results with large arrays on GPU.
405+
# getrf_batch can be slow on large GPU arrays.
406406
# Use getrf_batch only on CPU.
407-
# On GPU call getrf for each two-dimensional array by loop
407+
# On GPU fall back to calling getrf per 2D slice.
408408
use_batch = a.sycl_device.has_aspect_cpu
409409

410410
a_sycl_queue = a.sycl_queue
@@ -416,7 +416,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
416416
orig_shape = a.shape
417417
batch_shape = orig_shape[:-2]
418418

419-
# accommodate empty arrays
419+
# handle empty input
420420
if a.size == 0:
421421
lu = dpnp.empty_like(a)
422422
piv = dpnp.empty(
@@ -431,32 +431,33 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
431431
a = dpnp.reshape(a, (-1, m, n))
432432
batch_size = a.shape[0]
433433

434-
if use_batch:
435-
# Reorder the elements by moving the last two axes of `a` to the front
436-
# to match fortran-like array order which is assumed by getrf_batch
437-
a = dpnp.moveaxis(a, 0, -1)
434+
# Move batch axis to the end (m, n, batch) in Fortran order:
435+
# required by getrf_batch
436+
# and ensures each a[..., i] is F-contiguous for getrf
437+
a = dpnp.moveaxis(a, 0, -1)
438438

439-
a_usm_arr = dpnp.get_usm_ndarray(a)
439+
a_usm_arr = dpnp.get_usm_ndarray(a)
440440

441-
# `a` must be copied because getrf_batch destroys the input matrix
442-
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
443-
ipiv_h = dpnp.empty(
444-
(batch_size, k),
445-
dtype=dpnp.int64,
446-
order="C",
447-
usm_type=a_usm_type,
448-
sycl_queue=a_sycl_queue,
449-
)
441+
# `a` must be copied because getrf/getrf_batch destroys the input matrix
442+
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
443+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
444+
src=a_usm_arr,
445+
dst=a_h.get_array(),
446+
sycl_queue=a_sycl_queue,
447+
depends=_manager.submitted_events,
448+
)
449+
_manager.add_event_pair(ht_ev, copy_ev)
450450

451-
dev_info_h = [0] * batch_size
451+
ipiv_h = dpnp.empty(
452+
(batch_size, k),
453+
dtype=dpnp.int64,
454+
order="C",
455+
usm_type=a_usm_type,
456+
sycl_queue=a_sycl_queue,
457+
)
452458

453-
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
454-
src=a_usm_arr,
455-
dst=a_h.get_array(),
456-
sycl_queue=a_sycl_queue,
457-
depends=_manager.submitted_events,
458-
)
459-
_manager.add_event_pair(ht_ev, copy_ev)
459+
if use_batch:
460+
dev_info_h = [0] * batch_size
460461

461462
ipiv_stride = k
462463
a_stride = a_h.strides[-1]
@@ -477,12 +478,6 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
477478
)
478479
_manager.add_event_pair(ht_ev, getrf_ev)
479480

480-
# getrf_batch expects `a` in Fortran order and overwrites it.
481-
# Batch was moved to the last axis before the call.
482-
# Move it back to the front and reshape to the original shape.
483-
a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
484-
ipiv_h = ipiv_h.reshape((*batch_shape, k))
485-
486481
if any(dev_info_h):
487482
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
488483
warn(
@@ -491,77 +486,41 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
491486
RuntimeWarning,
492487
stacklevel=2,
493488
)
489+
else:
490+
dev_info_vecs = [[0] for _ in range(batch_size)]
491+
492+
# Sequential LU factorization using getrf per slice
493+
for i in range(batch_size):
494+
ht_ev, getrf_ev = li._getrf(
495+
a_sycl_queue,
496+
a_h[..., i].get_array(),
497+
ipiv_h[i].get_array(),
498+
dev_info_vecs[i],
499+
depends=[copy_ev],
500+
)
501+
_manager.add_event_pair(ht_ev, getrf_ev)
494502

495-
# MKL lapack uses 1-origin while SciPy uses 0-origin
496-
ipiv_h -= 1
497-
498-
# Return a tuple containing the factorized matrix 'a_h',
499-
# pivot indices 'ipiv_h'
500-
return (a_h, ipiv_h)
501-
502-
a_usm_arr = dpnp.get_usm_ndarray(a)
503-
504-
# Initialize lists for storing arrays and events for each batch
505-
a_vecs = [None] * batch_size
506-
ipiv_vecs = [None] * batch_size
507-
dev_info_vecs = [None] * batch_size
508-
dep_evs = _manager.submitted_events
509-
510-
# Process each batch
511-
for i in range(batch_size):
512-
# Copy each 2D slice to a new array because getrf will destroy
513-
# the input matrix
514-
a_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type)
515-
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
516-
src=a_usm_arr[i],
517-
dst=a_vecs[i].get_array(),
518-
sycl_queue=a_sycl_queue,
519-
depends=dep_evs,
520-
)
521-
_manager.add_event_pair(ht_ev, copy_ev)
522-
523-
ipiv_vecs[i] = dpnp.empty(
524-
(k,),
525-
dtype=dpnp.int64,
526-
order="C",
527-
usm_type=a_usm_type,
528-
sycl_queue=a_sycl_queue,
503+
diag_nums = ", ".join(
504+
str(v) for info in dev_info_vecs for v in info if v > 0
529505
)
506+
if diag_nums:
507+
warn(
508+
f"Diagonal number {diag_nums} are exactly zero. "
509+
"Singular matrix.",
510+
RuntimeWarning,
511+
stacklevel=2,
512+
)
530513

531-
dev_info_vecs[i] = [0]
532-
533-
# Call the LAPACK extension function _getrf
534-
# to perform LU decomposition on each batch in 'a_vecs[i]'
535-
ht_ev, getrf_ev = li._getrf(
536-
a_sycl_queue,
537-
a_vecs[i].get_array(),
538-
ipiv_vecs[i].get_array(),
539-
dev_info_vecs[i],
540-
depends=[copy_ev],
541-
)
542-
_manager.add_event_pair(ht_ev, getrf_ev)
543-
544-
# Reshape the results back to their original shape
545-
out_a = dpnp.array(a_vecs).reshape(orig_shape)
546-
out_ipiv = dpnp.array(ipiv_vecs).reshape((*batch_shape, k))
547-
548-
diag_nums = ", ".join(
549-
str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
550-
)
551-
552-
if diag_nums:
553-
warn(
554-
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
555-
RuntimeWarning,
556-
stacklevel=2,
557-
)
514+
# Restore original shape: move batch axis back and reshape
515+
a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
516+
ipiv_h = ipiv_h.reshape((*batch_shape, k))
558517

559-
# MKL lapack uses 1-origin while SciPy uses 0-origin
560-
out_ipiv -= 1
518+
# oneMKL LAPACK uses 1-origin while SciPy uses 0-origin
519+
ipiv_h -= 1
561520

562-
# Return a tuple containing the factorized matrix 'out_a',
563-
# pivot indices 'out_ipiv'
564-
return (out_a, out_ipiv)
521+
# Return a tuple containing the factorized matrix 'a_h',
522+
# pivot indices 'ipiv_h'
523+
return (a_h, ipiv_h)
565524

566525

567526
def _batched_solve(a, b, exec_q, res_usm_type, res_type):

0 commit comments

Comments
 (0)