
Commit d975818

Implement dpnp.linalg.lu_factor batch inputs (#2565)
This PR extends `dpnp.linalg.lu_factor()` (#2557) to support batch arrays. In addition, it includes an updated implementation of `getrf_batch` that supports non-square matrices.
1 parent 6eb6f6f commit d975818
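
A minimal usage sketch (illustrative shapes; assumes the public SciPy-style entry point `dpnp.linalg.lu_factor` named in the title, returning the combined LU factors and the 0-based pivots this PR produces):

import dpnp

# a batch of four 3 x 5 matrices; non-square slices are now supported
a = dpnp.random.rand(4, 3, 5)

lu, piv = dpnp.linalg.lu_factor(a)

print(lu.shape)   # (4, 3, 5) -- combined L and U factors of each slice
print(piv.shape)  # (4, 3)    -- 0-based pivot rows, k = min(m, n) = 3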

File tree

7 files changed: +279 -26 lines

dpnp/backend/extensions/lapack/getrf.hpp

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ extern std::pair<sycl::event, sycl::event>
     const dpctl::tensor::usm_ndarray &a_array,
     const dpctl::tensor::usm_ndarray &ipiv_array,
     py::list dev_info,
+    std::int64_t m,
     std::int64_t n,
     std::int64_t stride_a,
     std::int64_t stride_ipiv,

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 21 additions & 11 deletions
@@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils;
 typedef sycl::event (*getrf_batch_impl_fn_ptr_t)(
     sycl::queue &,
     std::int64_t,
+    std::int64_t,
     char *,
     std::int64_t,
     std::int64_t,
@@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t
 
 template <typename T>
 static sycl::event getrf_batch_impl(sycl::queue &exec_q,
+                                    std::int64_t m,
                                     std::int64_t n,
                                     char *in_a,
                                     std::int64_t lda,
@@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
     T *a = reinterpret_cast<T *>(in_a);
 
     const std::int64_t scratchpad_size =
-        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
+        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, m, n, lda, stride_a,
                                                    stride_ipiv, batch_size);
     T *scratchpad = nullptr;
 
@@ -91,11 +93,11 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
 
     getrf_batch_event = mkl_lapack::getrf_batch(
         exec_q,
-        n, // The order of each square matrix in the batch; (0 ≤ n).
+        m, // The number of rows in each matrix in the batch; (0 ≤ m).
            // It must be a non-negative integer.
         n, // The number of columns in each matrix in the batch; (0 ≤ n).
            // It must be a non-negative integer.
-        a, // Pointer to the batch of square matrices, each of size (n x n).
+        a, // Pointer to the batch of input matrices, each of size (m x n).
         lda,      // The leading dimension of each matrix in the batch.
         stride_a, // Stride between consecutive matrices in the batch.
         ipiv,     // Pointer to the array of pivot indices for each matrix in
@@ -179,6 +181,7 @@
     const dpctl::tensor::usm_ndarray &a_array,
     const dpctl::tensor::usm_ndarray &ipiv_array,
     py::list dev_info,
+    std::int64_t m,
     std::int64_t n,
     std::int64_t stride_a,
     std::int64_t stride_ipiv,
@@ -191,21 +194,21 @@
     if (a_array_nd < 3) {
         throw py::value_error(
             "The input array has ndim=" + std::to_string(a_array_nd) +
-            ", but an array with ndim >= 3 is expected.");
+            ", but an array with ndim >= 3 is expected");
     }
 
     if (ipiv_array_nd != 2) {
         throw py::value_error("The array of pivot indices has ndim=" +
                               std::to_string(ipiv_array_nd) +
-                              ", but a 2-dimensional array is expected.");
+                              ", but a 2-dimensional array is expected");
     }
 
     const int dev_info_size = py::len(dev_info);
     if (dev_info_size != batch_size) {
         throw py::value_error("The size of 'dev_info' (" +
                               std::to_string(dev_info_size) +
                               ") does not match the expected batch size (" +
-                              std::to_string(batch_size) + ").");
+                              std::to_string(batch_size) + ")");
     }
 
     // check compatibility of execution queue and allocation queue
@@ -241,27 +244,34 @@
     if (getrf_batch_fn == nullptr) {
         throw py::value_error(
             "No getrf_batch implementation defined for the provided type "
-            "of the input matrix.");
+            "of the input matrix");
     }
 
     auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
     int ipiv_array_type_id =
         ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
 
     if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
-        throw py::value_error("The type of 'ipiv_array' must be int64.");
+        throw py::value_error("The type of 'ipiv_array' must be int64");
+    }
+
+    const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw();
+    if (ipiv_array_shape[0] != batch_size ||
+        ipiv_array_shape[1] != std::min(m, n)) {
+        throw py::value_error(
+            "The shape of 'ipiv_array' must be (batch_size, min(m, n))");
     }
 
     char *a_array_data = a_array.get_data();
-    const std::int64_t lda = std::max<size_t>(1UL, n);
+    const std::int64_t lda = std::max<size_t>(1UL, m);
 
     char *ipiv_array_data = ipiv_array.get_data();
     std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
 
     std::vector<sycl::event> host_task_events;
     sycl::event getrf_batch_ev = getrf_batch_fn(
-        exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size,
-        dev_info, host_task_events, depends);
+        exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv,
+        batch_size, dev_info, host_task_events, depends);
 
     sycl::event args_ev = dpctl::utils::keep_args_alive(
         exec_q, {a_array, ipiv_array}, host_task_events);
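
The `lda` and stride changes above assume the batch is stored column-major with the batch axis last. A minimal sketch of how those quantities fall out of the shape (dpnp reports strides in elements, which is what the Python caller below relies on):

import dpnp

m, n, batch_size = 3, 5, 4

# F-ordered (m, n, batch) buffer: each a_h[..., i] slice is F-contiguous
a_h = dpnp.empty((m, n, batch_size), dtype=dpnp.float64, order="F")

lda = max(1, m)             # leading dimension of each m x n matrix
stride_a = a_h.strides[-1]  # element stride between consecutive matrices
stride_ipiv = min(m, n)     # one pivot row of length min(m, n) per matrix

assert stride_a == m * n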

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 3 additions & 3 deletions
@@ -141,10 +141,10 @@ PYBIND11_MODULE(_lapack_impl, m)
 
     m.def("_getrf_batch", &lapack_ext::getrf_batch,
           "Call `getrf_batch` from OneMKL LAPACK library to return "
-          "the LU factorization of a batch of general n x n matrices",
+          "the LU factorization of a batch of general m x n matrices",
           py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
-          py::arg("dev_info_array"), py::arg("n"), py::arg("stride_a"),
-          py::arg("stride_ipiv"), py::arg("batch_size"),
+          py::arg("dev_info_array"), py::arg("m"), py::arg("n"),
+          py::arg("stride_a"), py::arg("stride_ipiv"), py::arg("batch_size"),
           py::arg("depends") = py::list());
 
     m.def("_getri_batch", &lapack_ext::getri_batch,

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 136 additions & 7 deletions
@@ -246,6 +246,7 @@ def _batched_inv(a, res_type):
             ipiv_h.get_array(),
             dev_info,
             n,
+            n,
             a_stride,
             ipiv_stride,
             batch_size,
@@ -327,6 +328,7 @@ def _batched_lu_factor(a, res_type):
             ipiv_h.get_array(),
             dev_info_h,
             n,
+            n,
             a_stride,
             ipiv_stride,
             batch_size,
@@ -396,6 +398,131 @@ def _batched_lu_factor(a, res_type):
     return (out_a, out_ipiv, out_dev_info)
 
 
+def _batched_lu_factor_scipy(a, res_type):  # pylint: disable=too-many-locals
+    """SciPy-compatible LU factorization for batched inputs."""
+
+    # TODO: Find out at which array sizes the best performance is obtained
+    # getrf_batch can be slow on large GPU arrays.
+    # Use getrf_batch only on CPU.
+    # On GPU fall back to calling getrf per 2D slice.
+    use_batch = a.sycl_device.has_aspect_cpu
+
+    a_sycl_queue = a.sycl_queue
+    a_usm_type = a.usm_type
+    _manager = dpu.SequentialOrderManager[a_sycl_queue]
+
+    m, n = a.shape[-2:]
+    k = min(m, n)
+    orig_shape = a.shape
+    batch_shape = orig_shape[:-2]
+
+    # handle empty input
+    if a.size == 0:
+        lu = dpnp.empty_like(a)
+        piv = dpnp.empty(
+            (*batch_shape, k),
+            dtype=dpnp.int64,
+            usm_type=a_usm_type,
+            sycl_queue=a_sycl_queue,
+        )
+        return lu, piv
+
+    # get 3d input arrays by reshape
+    a = dpnp.reshape(a, (-1, m, n))
+    batch_size = a.shape[0]
+
+    # Move batch axis to the end (m, n, batch) in Fortran order:
+    # required by getrf_batch
+    # and ensures each a[..., i] is F-contiguous for getrf
+    a = dpnp.moveaxis(a, 0, -1)
+
+    a_usm_arr = dpnp.get_usm_ndarray(a)
+
+    # `a` must be copied because getrf/getrf_batch destroys the input matrix
+    a_h = dpnp.empty_like(a, order="F", dtype=res_type)
+    ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+        src=a_usm_arr,
+        dst=a_h.get_array(),
+        sycl_queue=a_sycl_queue,
+        depends=_manager.submitted_events,
+    )
+    _manager.add_event_pair(ht_ev, copy_ev)
+
+    ipiv_h = dpnp.empty(
+        (batch_size, k),
+        dtype=dpnp.int64,
+        order="C",
+        usm_type=a_usm_type,
+        sycl_queue=a_sycl_queue,
+    )
+
+    if use_batch:
+        dev_info_h = [0] * batch_size
+
+        ipiv_stride = k
+        a_stride = a_h.strides[-1]
+
+        # Call the LAPACK extension function _getrf_batch
+        # to perform LU decomposition of a batch of general matrices
+        ht_ev, getrf_ev = li._getrf_batch(
+            a_sycl_queue,
+            a_h.get_array(),
+            ipiv_h.get_array(),
+            dev_info_h,
+            m,
+            n,
+            a_stride,
+            ipiv_stride,
+            batch_size,
+            depends=[copy_ev],
+        )
+        _manager.add_event_pair(ht_ev, getrf_ev)
+
+        if any(dev_info_h):
+            diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
+            warn(
+                f"Diagonal number {diag_nums} are exactly zero. "
+                "Singular matrix.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+    else:
+        dev_info_vecs = [[0] for _ in range(batch_size)]
+
+        # Sequential LU factorization using getrf per slice
+        for i in range(batch_size):
+            ht_ev, getrf_ev = li._getrf(
+                a_sycl_queue,
+                a_h[..., i].get_array(),
+                ipiv_h[i].get_array(),
+                dev_info_vecs[i],
+                depends=[copy_ev],
+            )
+            _manager.add_event_pair(ht_ev, getrf_ev)
+
+        diag_nums = ", ".join(
+            str(v) for info in dev_info_vecs for v in info if v > 0
+        )
+        if diag_nums:
+            warn(
+                f"Diagonal number {diag_nums} are exactly zero. "
+                "Singular matrix.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+
+    # Restore original shape: move batch axis back and reshape
+    a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
+    ipiv_h = ipiv_h.reshape((*batch_shape, k))
+
+    # oneMKL LAPACK uses 1-origin while SciPy uses 0-origin
+    ipiv_h -= 1
+
+    # Return a tuple containing the factorized matrix 'a_h',
+    # pivot indices 'ipiv_h'
+    return (a_h, ipiv_h)
+
+
 def _batched_solve(a, b, exec_q, res_usm_type, res_type):
     """
     _batched_solve(a, b, exec_q, res_usm_type, res_type)
@@ -2308,6 +2435,15 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
     a_sycl_queue = a.sycl_queue
     a_usm_type = a.usm_type
 
+    if check_finite:
+        if not dpnp.isfinite(a).all():
+            raise ValueError("array must not contain infs or NaNs")
+
+    if a.ndim > 2:
+        # SciPy always copies each 2D slice,
+        # so `overwrite_a` is ignored here
+        return _batched_lu_factor_scipy(a, res_type)
+
     # accommodate empty arrays
     if a.size == 0:
         lu = dpnp.empty_like(a)
@@ -2316,13 +2452,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
         )
         return lu, piv
 
-    if check_finite:
-        if not dpnp.isfinite(a).all():
-            raise ValueError("array must not contain infs or NaNs")
-
-    if a.ndim > 2:
-        raise NotImplementedError("Batched matrices are not supported")
-
     _manager = dpu.SequentialOrderManager[a_sycl_queue]
     a_usm_arr = dpnp.get_usm_ndarray(a)

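Since `ipiv_h -= 1` converts oneMKL's 1-origin pivots to SciPy's 0-origin convention, the result plugs into the usual SciPy reconstruction recipe. A sketch under the same naming assumptions as above (getrf records that row i was swapped with row piv[i]; applying the swaps in order recovers the row permutation):

import dpnp

a = dpnp.random.rand(2, 4, 4)       # batch of two 4 x 4 matrices
lu, piv = dpnp.linalg.lu_factor(a)  # piv: shape (2, 4), 0-based

# turn the sequential row swaps of the first slice into a permutation
perm = list(range(4))
for i, p in enumerate(int(v) for v in piv[0]):
    perm[i], perm[p] = perm[p], perm[i]

L = dpnp.tril(lu[0], k=-1) + dpnp.eye(4)  # unit lower-triangular factor
U = dpnp.triu(lu[0])                      # upper-triangular factor
assert dpnp.allclose(dpnp.take(a[0], dpnp.asarray(perm), axis=0), L @ U)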