Skip to content

Commit 25850f4

Browse files
Extend dpnp.linalg.lu_factor to support batched inputs
1 parent c76aba5 commit 25850f4

File tree

1 file changed

+156
-1
lines changed

1 file changed

+156
-1
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,159 @@ def _batched_lu_factor(a, res_type):
398398
return (out_a, out_ipiv, out_dev_info)
399399

400400

401+
def _batched_lu_factor_scipy(a, res_type):
    """SciPy-compatible LU factorization for batched inputs.

    Factorizes each 2D matrix in the (..., M, N) stack `a` as P @ L @ U
    via oneMKL LAPACK getrf/getrf_batch and returns results in SciPy's
    ``lu_factor`` convention (0-origin pivots, warning instead of error
    on singular input).

    Parameters
    ----------
    a : dpnp.ndarray
        Input stack of matrices, shape (..., M, N), ndim > 2.
    res_type : dpnp.dtype
        Floating-point dtype in which the factorization is computed.

    Returns
    -------
    tuple of dpnp.ndarray
        ``(lu, piv)`` where ``lu`` has shape (..., M, N) holding the
        combined L (unit lower) and U factors, and ``piv`` has shape
        (..., K) with K = min(M, N) holding 0-origin pivot indices.
    """

    # TODO: Find out at which array sizes the best performance is obtained
    # getrf_batch implementation shows slow results with large arrays on GPU.
    # Use getrf_batch only on CPU.
    # On GPU call getrf for each two-dimensional array by loop
    use_batch = a.sycl_device.has_aspect_cpu

    a_sycl_queue = a.sycl_queue
    a_usm_type = a.usm_type
    # Serializes dependent USM operations on this queue (copy -> getrf).
    _manager = dpu.SequentialOrderManager[a_sycl_queue]

    m, n = a.shape[-2:]
    k = min(m, n)
    orig_shape = a.shape
    # get 3d input arrays by reshape: collapse all leading (batch) axes
    a = dpnp.reshape(a, (-1, m, n))
    batch_size = a.shape[0]

    if use_batch:
        # Reorder the elements by moving the last two axes of `a` to the front
        # to match fortran-like array order which is assumed by getrf_batch
        a = dpnp.moveaxis(a, 0, -1)

        a_usm_arr = dpnp.get_usm_ndarray(a)

        # `a` must be copied because getrf_batch destroys the input matrix
        a_h = dpnp.empty_like(a, order="F", dtype=res_type)
        # Per-matrix pivot indices; int64 as expected by the LAPACK extension
        ipiv_h = dpnp.empty(
            (batch_size, k),
            dtype=dpnp.int64,
            order="C",
            usm_type=a_usm_type,
            sycl_queue=a_sycl_queue,
        )

        # Host-side per-matrix LAPACK info codes (0 == success,
        # >0 == index of an exactly-zero diagonal element of U)
        dev_info_h = [0] * batch_size

        # Copy input into the Fortran-ordered working buffer
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=a_usm_arr,
            dst=a_h.get_array(),
            sycl_queue=a_sycl_queue,
            depends=_manager.submitted_events,
        )
        _manager.add_event_pair(ht_ev, copy_ev)

        # Strides between consecutive matrices/pivot vectors in the batch;
        # with the batch on the last axis of an F-ordered array, the last
        # stride steps from one matrix to the next.
        ipiv_stride = k
        a_stride = a_h.strides[-1]

        # Call the LAPACK extension function _getrf_batch
        # to perform LU decomposition of a batch of general matrices
        ht_ev, getrf_ev = li._getrf_batch(
            a_sycl_queue,
            a_h.get_array(),
            ipiv_h.get_array(),
            dev_info_h,
            m,
            n,
            a_stride,
            ipiv_stride,
            batch_size,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_ev, getrf_ev)

        # getrf_batch expects `a` in Fortran order and overwrites it.
        # Batch was moved to the last axis before the call.
        # Move it back to the front and reshape to the original shape.
        a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
        ipiv_h = ipiv_h.reshape((*orig_shape[:-2], k))

        # SciPy semantics: singular input warns rather than raises
        if any(dev_info_h):
            diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
            warn(
                f"Diagonal number {diag_nums} are exactly zero. "
                "Singular matrix.",
                RuntimeWarning,
                stacklevel=2,
            )

        # MKL lapack uses 1-origin while SciPy uses 0-origin
        ipiv_h -= 1

        # Return a tuple containing the factorized matrix 'a_h',
        # pivot indices 'ipiv_h'
        return (a_h, ipiv_h)

    # --- non-batched path: one getrf call per 2D slice (GPU) ---
    a_usm_arr = dpnp.get_usm_ndarray(a)

    # Initialize lists for storing arrays and events for each batch
    a_vecs = [None] * batch_size
    ipiv_vecs = [None] * batch_size
    dev_info_vecs = [None] * batch_size
    dep_evs = _manager.submitted_events

    # Process each batch
    for i in range(batch_size):
        # Copy each 2D slice to a new array because getrf will destroy
        # the input matrix
        a_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type)
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=a_usm_arr[i],
            dst=a_vecs[i].get_array(),
            sycl_queue=a_sycl_queue,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, copy_ev)

        ipiv_vecs[i] = dpnp.empty(
            (k,),
            dtype=dpnp.int64,
            order="C",
            usm_type=a_usm_type,
            sycl_queue=a_sycl_queue,
        )

        # Single-element info list for this slice (0 == success)
        dev_info_vecs[i] = [0]

        # Call the LAPACK extension function _getrf
        # to perform LU decomposition on each batch in 'a_vecs[i]'
        ht_ev, getrf_ev = li._getrf(
            a_sycl_queue,
            a_vecs[i].get_array(),
            ipiv_vecs[i].get_array(),
            dev_info_vecs[i],
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_ev, getrf_ev)

    # Reshape the results back to their original shape.
    # NOTE(review): dpnp.array(a_vecs) gathers the per-slice results and
    # presumably synchronizes on the submitted events — confirm against
    # dpnp.array semantics for lists of dpnp arrays.
    out_a = dpnp.array(a_vecs).reshape(orig_shape)
    out_ipiv = dpnp.array(ipiv_vecs).reshape((*orig_shape[:-2], k))

    # Collect positive info codes (zero diagonal in U) across all slices
    diag_nums = ", ".join(
        str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0
    )

    # SciPy semantics: singular input warns rather than raises
    if diag_nums:
        warn(
            f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
            RuntimeWarning,
            stacklevel=2,
        )

    # MKL lapack uses 1-origin while SciPy uses 0-origin
    out_ipiv -= 1

    # Return a tuple containing the factorized matrix 'out_a',
    # pivot indices 'out_ipiv'
    return (out_a, out_ipiv)
552+
553+
401554
def _batched_solve(a, b, exec_q, res_usm_type, res_type):
402555
"""
403556
_batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -2323,7 +2476,9 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23232476
raise ValueError("array must not contain infs or NaNs")
23242477

23252478
if a.ndim > 2:
2326-
raise NotImplementedError("Batched matrices are not supported")
2479+
# SciPy always copies each 2D slice,
2480+
# so `overwrite_a` is ignored here
2481+
return _batched_lu_factor_scipy(a, res_type)
23272482

23282483
_manager = dpu.SequentialOrderManager[a_sycl_queue]
23292484
a_usm_arr = dpnp.get_usm_ndarray(a)

0 commit comments

Comments
 (0)