Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Improved performance of `dpnp.isclose` function by implementing a dedicated kernel for scalar `rtol` and `atol` arguments [#2540](https://github.com/IntelPython/dpnp/pull/2540)
* Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535)
* Redesigned `dpnp.erf` function through a pybind11 extension of the oneMKL call or a dedicated kernel in the `ufunc` namespace [#2551](https://github.com/IntelPython/dpnp/pull/2551)
* Improved performance of the batched implementation of `dpnp.linalg.det` and `dpnp.linalg.slogdet` [#2572](https://github.com/IntelPython/dpnp/pull/2572)

### Deprecated

Expand Down
103 changes: 33 additions & 70 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,26 +297,27 @@ def _batched_lu_factor(a, res_type):
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

# `a` must be copied because getrf/getrf_batch destroys the input matrix
a_h = dpnp.empty_like(a, order="C", dtype=res_type)
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)

ipiv_h = dpnp.empty(
(batch_size, n),
dtype=dpnp.int64,
order="C",
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)

if use_batch:
# `a` must be copied because getrf_batch destroys the input matrix
a_h = dpnp.empty_like(a, order="C", dtype=res_type)
ipiv_h = dpnp.empty(
(batch_size, n),
dtype=dpnp.int64,
order="C",
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
dev_info_h = [0] * batch_size

ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)

ipiv_stride = n
a_stride = a_h.strides[0]

Expand All @@ -336,63 +337,25 @@ def _batched_lu_factor(a, res_type):
)
_manager.add_event_pair(ht_ev, getrf_ev)

dev_info_array = dpnp.array(
dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue
)

# Reshape the results back to their original shape
a_h = a_h.reshape(orig_shape)
ipiv_h = ipiv_h.reshape(orig_shape[:-1])
dev_info_array = dev_info_array.reshape(orig_shape[:-2])

return (a_h, ipiv_h, dev_info_array)

# Initialize lists for storing arrays and events for each batch
a_vecs = [None] * batch_size
ipiv_vecs = [None] * batch_size
dev_info_vecs = [None] * batch_size

dep_evs = _manager.submitted_events

# Process each batch
for i in range(batch_size):
# Copy each 2D slice to a new array because getrf will destroy
# the input matrix
a_vecs[i] = dpnp.empty_like(a[i], order="C", dtype=res_type)

ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr[i],
dst=a_vecs[i].get_array(),
sycl_queue=a_sycl_queue,
depends=dep_evs,
)
_manager.add_event_pair(ht_ev, copy_ev)

ipiv_vecs[i] = dpnp.empty(
(n,),
dtype=dpnp.int64,
order="C",
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
dev_info_vecs[i] = [0]
else:
dev_info_h = [[0] for _ in range(batch_size)]

# Call the LAPACK extension function _getrf
# to perform LU decomposition on each batch in 'a_vecs[i]'
ht_ev, getrf_ev = li._getrf(
a_sycl_queue,
a_vecs[i].get_array(),
ipiv_vecs[i].get_array(),
dev_info_vecs[i],
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, getrf_ev)
# Sequential LU factorization using getrf per slice
for i in range(batch_size):
ht_ev, getrf_ev = li._getrf(
a_sycl_queue,
a_h[i].get_array(),
ipiv_h[i].get_array(),
dev_info_h[i],
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, getrf_ev)

# Reshape the results back to their original shape
out_a = dpnp.array(a_vecs, order="C").reshape(orig_shape)
out_ipiv = dpnp.array(ipiv_vecs).reshape(orig_shape[:-1])
out_a = a_h.reshape(orig_shape)
out_ipiv = ipiv_h.reshape(orig_shape[:-1])
out_dev_info = dpnp.array(
dev_info_vecs, usm_type=a_usm_type, sycl_queue=a_sycl_queue
dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue
).reshape(orig_shape[:-2])

return (out_a, out_ipiv, out_dev_info)
Expand Down
Loading