From 866c29df8ead0cf6893a7030c1fa782e64d7b0f5 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 5 Aug 2025 02:59:40 -0700 Subject: [PATCH 01/12] Implement dpnp_lu_factor() and update _lu_factor() --- dpnp/linalg/dpnp_utils_linalg.py | 90 ++++++++++++++++++++++++++------ 1 file changed, 74 insertions(+), 16 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 838daac66303..ffd0ee29320b 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -1009,7 +1009,7 @@ def _is_empty_2d(arr): return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0 -def _lu_factor(a, res_type): +def _lu_factor(a, res_type, scipy=False, overwrite_a=False): """ Compute pivoted LU decomposition. @@ -1050,18 +1050,41 @@ def _lu_factor(a, res_type): a_usm_arr = dpnp.get_usm_ndarray(a) - # `a` must be copied because getrf destroys the input matrix - a_h = dpnp.empty_like(a, order="C", dtype=res_type) + if not scipy: + # Internal use case (e.g., det(), slogdet()). Always copy. + # `a` must be copied because getrf destroys the input matrix + a_h = dpnp.empty_like(a, order="C", dtype=res_type) - # use DPCTL tensor function to fill the сopy of the input array - # from the input array - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_h.get_array(), - sycl_queue=a_sycl_queue, - depends=_manager.submitted_events, - ) - _manager.add_event_pair(ht_ev, copy_ev) + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) + + else: + # SciPy-compatible behavior + # Copy is required if: + # - overwrite_a is False (always copy), + # - dtype mismatch, + # - not F-contiguous, + # - not writeable + if not overwrite_a or _is_copy_required(a, res_type): + a_h = dpnp.empty_like(a, order="F", dtype=res_type) + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) + else: + # input is suitable for in-place modification + a_h = a + copy_ev = None ipiv_h = dpnp.empty( n, @@ -1079,13 +1102,18 @@ def _lu_factor(a, res_type): a_h.get_array(), ipiv_h.get_array(), dev_info_h, - depends=[copy_ev], + depends=[copy_ev] if copy_ev is not None else [], ) _manager.add_event_pair(ht_ev, getrf_ev) - dev_info_array = dpnp.array( - dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue - ) + # Return list if called in SciPy-compatible mode + # else dpnp.ndarray + if scipy: + dev_info_array = dev_info_h + else: + dev_info_array = dpnp.array( + dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue + ) # Return a tuple containing the factorized matrix 'a_h', # pivot indices 'ipiv_h' @@ -1093,6 +1121,36 @@ def _lu_factor(a, res_type): return (a_h, ipiv_h, dev_info_array) +def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): + """Compute pivoted LU decomposition.""" + + res_type = _common_type(a) + + # accommodate empty arrays + if a.size == 0: + lu = dpnp.empty_like(a) + piv = dpnp.arange(0, dtype=dpnp.int32) + return lu, piv + + if check_finite: + if not dpnp.isfinite(a).all(): + raise ValueError("array must not contain infs or NaNs") + + lu, piv, dev_info = _lu_factor( + a, res_type, scipy=True, overwrite_a=overwrite_a + ) + + if any(dev_info): + diag_nums = ", ".join(str(v) for v in dev_info if v > 0) + warn( + f"Diagonal number {diag_nums} are exactly zero. Singular matrix.", + RuntimeWarning, + stacklevel=2, + ) + + return lu, piv + + def _multi_dot(arrays, order, i, j, out=None): """Actually do the multiplication with the given order.""" From 02b645e8500fb7635756a2f01929077a08e10e4a Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 5 Aug 2025 05:00:41 -0700 Subject: [PATCH 02/12] Simplify lu_factor by moving logic to dpnp_lu_factor --- dpnp/linalg/dpnp_utils_linalg.py | 90 ++++++-------------------------- 1 file changed, 16 insertions(+), 74 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index ffd0ee29320b..838daac66303 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -1009,7 +1009,7 @@ def _is_empty_2d(arr): return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0 -def _lu_factor(a, res_type, scipy=False, overwrite_a=False): +def _lu_factor(a, res_type): """ Compute pivoted LU decomposition. @@ -1050,41 +1050,18 @@ def _lu_factor(a, res_type, scipy=False, overwrite_a=False): a_usm_arr = dpnp.get_usm_ndarray(a) - if not scipy: - # Internal use case (e.g., det(), slogdet()). Always copy. - # `a` must be copied because getrf destroys the input matrix - a_h = dpnp.empty_like(a, order="C", dtype=res_type) + # `a` must be copied because getrf destroys the input matrix + a_h = dpnp.empty_like(a, order="C", dtype=res_type) - # use DPCTL tensor function to fill the сopy of the input array - # from the input array - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_h.get_array(), - sycl_queue=a_sycl_queue, - depends=_manager.submitted_events, - ) - _manager.add_event_pair(ht_ev, copy_ev) - - else: - # SciPy-compatible behavior - # Copy is required if: - # - overwrite_a is False (always copy), - # - dtype mismatch, - # - not F-contiguous, - # - not writeable - if not overwrite_a or _is_copy_required(a, res_type): - a_h = dpnp.empty_like(a, order="F", dtype=res_type) - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_h.get_array(), - sycl_queue=a_sycl_queue, - depends=_manager.submitted_events, - ) - _manager.add_event_pair(ht_ev, copy_ev) - else: - # input is suitable for in-place modification - a_h = a - copy_ev = None + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) ipiv_h = dpnp.empty( n, @@ -1102,18 +1079,13 @@ def _lu_factor(a, res_type, scipy=False, overwrite_a=False): a_h.get_array(), ipiv_h.get_array(), dev_info_h, - depends=[copy_ev] if copy_ev is not None else [], + depends=[copy_ev], ) _manager.add_event_pair(ht_ev, getrf_ev) - # Return list if called in SciPy-compatible mode - # else dpnp.ndarray - if scipy: - dev_info_array = dev_info_h - else: - dev_info_array = dpnp.array( - dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue - ) + dev_info_array = dpnp.array( + dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue + ) # Return a tuple containing the factorized matrix 'a_h', # pivot indices 'ipiv_h' @@ -1121,36 +1093,6 @@ def _lu_factor(a, res_type, scipy=False, overwrite_a=False): return (a_h, ipiv_h, dev_info_array) -def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): - """Compute pivoted LU decomposition.""" - - res_type = _common_type(a) - - # accommodate empty arrays - if a.size == 0: - lu = dpnp.empty_like(a) - piv = dpnp.arange(0, dtype=dpnp.int32) - return lu, piv - - if check_finite: - if not dpnp.isfinite(a).all(): - raise ValueError("array must not contain infs or NaNs") - - lu, piv, dev_info = _lu_factor( - a, res_type, scipy=True, overwrite_a=overwrite_a - ) - - if any(dev_info): - diag_nums = ", ".join(str(v) for v in dev_info if v > 0) - warn( - f"Diagonal number {diag_nums} are exactly zero. Singular matrix.", - RuntimeWarning, - stacklevel=2, - ) - - return lu, piv - - def _multi_dot(arrays, order, i, j, out=None): """Actually do the multiplication with the given order.""" From df94dfff466eaf2ac9a748c417c54173c3d5e1b6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 13 Aug 2025 03:38:05 -0700 Subject: [PATCH 03/12] Add TestLuFactor --- dpnp/tests/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index dd5daa99b74d..15f452753fdf 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -15,6 +15,7 @@ ) import dpnp +import dpnp.linalg from .helper import ( assert_dtype_allclose, From 0f51521ff4551dc76807ebf53808c228f60e0e42 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 13 Aug 2025 05:54:55 -0700 Subject: [PATCH 04/12] Add test_overwrite_copy_special to improve coverage --- dpnp/tests/test_linalg.py | 48 +++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 15f452753fdf..3797d0dbc124 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -278,12 +278,15 @@ def test_cholesky_errors(self): class TestCond: - _norms = [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] + def setup_method(self): + numpy.random.seed(70) @pytest.mark.parametrize( - "shape", [(0, 4, 4), (4, 0, 3, 3)], ids=["(0, 4, 4)", "(4, 0, 3, 3)"] + "shape", [(0, 4, 4), (4, 0, 3, 3)], ids=["(0, 5, 3)", "(4, 0, 2, 3)"] + ) + @pytest.mark.parametrize( + "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] ) - @pytest.mark.parametrize("p", _norms) def test_empty(self, shape, p): a = numpy.empty(shape) ia = dpnp.array(a) @@ -292,27 +295,26 @@ def test_empty(self, shape, p): expected = numpy.linalg.cond(a, p=p) assert_dtype_allclose(result, expected) - # TODO: uncomment once numpy 2.3.3 release is published - # @testing.with_requires("numpy>=2.3.3") @pytest.mark.parametrize( "dtype", get_all_dtypes(no_none=True, no_bool=True) ) @pytest.mark.parametrize( "shape", [(4, 4), (2, 4, 3, 3)], ids=["(4, 4)", "(2, 4, 3, 3)"] ) - @pytest.mark.parametrize("p", _norms) + @pytest.mark.parametrize( + "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] + ) def test_basic(self, dtype, shape, p): a = generate_random_numpy_array(shape, dtype) ia = dpnp.array(a) result = dpnp.linalg.cond(ia, p=p) expected = numpy.linalg.cond(a, p=p) - # TODO: remove when numpy#29333 is released - if numpy_version() < "2.3.3": - expected = expected.real assert_dtype_allclose(result, expected, factor=16) - @pytest.mark.parametrize("p", _norms) + @pytest.mark.parametrize( + "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] + ) def test_bool(self, p): a = numpy.array([[True, True], [True, False]]) ia = dpnp.array(a) @@ -321,7 +323,9 @@ def test_bool(self, p): expected = numpy.linalg.cond(a, p=p) assert_dtype_allclose(result, expected) - @pytest.mark.parametrize("p", _norms) + @pytest.mark.parametrize( + "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] + ) def test_nan_to_inf(self, p): a = numpy.zeros((2, 2)) ia = dpnp.array(a) @@ -339,7 +343,9 @@ def test_nan_to_inf(self, p): else: assert_raises(dpnp.linalg.LinAlgError, dpnp.linalg.cond, ia, p=p) - @pytest.mark.parametrize("p", _norms) + @pytest.mark.parametrize( + "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] + ) @pytest.mark.parametrize( "stride", [(-2, -3, 2, -2), (-2, 4, -4, -4), (2, 3, 4, 4), (-1, 3, 3, -3)], @@ -351,23 +357,21 @@ def test_nan_to_inf(self, p): ], ) def test_strided(self, p, stride): - A = generate_random_numpy_array( - (6, 8, 10, 10), seed_value=70, low=0, high=1 - ) - iA = dpnp.array(A) + A = numpy.random.rand(6, 8, 10, 10) + B = dpnp.asarray(A) slices = tuple(slice(None, None, stride[i]) for i in range(A.ndim)) - a, ia = A[slices], iA[slices] + a = A[slices] + b = B[slices] - result = dpnp.linalg.cond(ia, p=p) + result = dpnp.linalg.cond(b, p=p) expected = numpy.linalg.cond(a, p=p) assert_dtype_allclose(result, expected, factor=24) - @pytest.mark.parametrize("xp", [dpnp, numpy]) - def test_error(self, xp): + def test_error(self): # cond is not defined on empty arrays - a = xp.empty((2, 0)) + ia = dpnp.empty((2, 0)) with pytest.raises(ValueError): - xp.linalg.cond(a, p=1) + dpnp.linalg.cond(ia, p=1) class TestDet: From 86f5073c92a21915528320d2e43b8e407e2c9143 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 13 Aug 2025 07:50:00 -0700 Subject: [PATCH 05/12] Align test_linalg.py with master --- dpnp/tests/test_linalg.py | 49 ++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 3797d0dbc124..dd5daa99b74d 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -15,7 +15,6 @@ ) import dpnp -import dpnp.linalg from .helper import ( assert_dtype_allclose, @@ -278,15 +277,12 @@ def test_cholesky_errors(self): class TestCond: - def setup_method(self): - numpy.random.seed(70) + _norms = [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] @pytest.mark.parametrize( - "shape", [(0, 4, 4), (4, 0, 3, 3)], ids=["(0, 5, 3)", "(4, 0, 2, 3)"] - ) - @pytest.mark.parametrize( - "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] + "shape", [(0, 4, 4), (4, 0, 3, 3)], ids=["(0, 4, 4)", "(4, 0, 3, 3)"] ) + @pytest.mark.parametrize("p", _norms) def test_empty(self, shape, p): a = numpy.empty(shape) ia = dpnp.array(a) @@ -295,26 +291,27 @@ def test_empty(self, shape, p): expected = numpy.linalg.cond(a, p=p) assert_dtype_allclose(result, expected) + # TODO: uncomment once numpy 2.3.3 release is published + # @testing.with_requires("numpy>=2.3.3") @pytest.mark.parametrize( "dtype", get_all_dtypes(no_none=True, no_bool=True) ) @pytest.mark.parametrize( "shape", [(4, 4), (2, 4, 3, 3)], ids=["(4, 4)", "(2, 4, 3, 3)"] ) - @pytest.mark.parametrize( - "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] - ) + @pytest.mark.parametrize("p", _norms) def test_basic(self, dtype, shape, p): a = generate_random_numpy_array(shape, dtype) ia = dpnp.array(a) result = dpnp.linalg.cond(ia, p=p) expected = numpy.linalg.cond(a, p=p) + # TODO: remove when numpy#29333 is released + if numpy_version() < "2.3.3": + expected = expected.real assert_dtype_allclose(result, expected, factor=16) - @pytest.mark.parametrize( - "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] - ) + @pytest.mark.parametrize("p", _norms) def test_bool(self, p): a = numpy.array([[True, True], [True, False]]) ia = dpnp.array(a) @@ -323,9 +320,7 @@ def test_bool(self, p): expected = numpy.linalg.cond(a, p=p) assert_dtype_allclose(result, expected) - @pytest.mark.parametrize( - "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] - ) + @pytest.mark.parametrize("p", _norms) def test_nan_to_inf(self, p): a = numpy.zeros((2, 2)) ia = dpnp.array(a) @@ -343,9 +338,7 @@ def test_nan_to_inf(self, p): else: assert_raises(dpnp.linalg.LinAlgError, dpnp.linalg.cond, ia, p=p) - @pytest.mark.parametrize( - "p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"] - ) + @pytest.mark.parametrize("p", _norms) @pytest.mark.parametrize( "stride", [(-2, -3, 2, -2), (-2, 4, -4, -4), (2, 3, 4, 4), (-1, 3, 3, -3)], @@ -357,21 +350,23 @@ def test_nan_to_inf(self, p): ], ) def test_strided(self, p, stride): - A = numpy.random.rand(6, 8, 10, 10) - B = dpnp.asarray(A) + A = generate_random_numpy_array( + (6, 8, 10, 10), seed_value=70, low=0, high=1 + ) + iA = dpnp.array(A) slices = tuple(slice(None, None, stride[i]) for i in range(A.ndim)) - a = A[slices] - b = B[slices] + a, ia = A[slices], iA[slices] - result = dpnp.linalg.cond(b, p=p) + result = dpnp.linalg.cond(ia, p=p) expected = numpy.linalg.cond(a, p=p) assert_dtype_allclose(result, expected, factor=24) - def test_error(self): + @pytest.mark.parametrize("xp", [dpnp, numpy]) + def test_error(self, xp): # cond is not defined on empty arrays - ia = dpnp.empty((2, 0)) + a = xp.empty((2, 0)) with pytest.raises(ValueError): - dpnp.linalg.cond(ia, p=1) + xp.linalg.cond(a, p=1) class TestDet: From b6943d78ce4de98c62ae3b5b7908fcda707cb10c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 20 Aug 2025 09:40:30 -0700 Subject: [PATCH 06/12] Extend getrf_batch to support non-square matrices --- dpnp/backend/extensions/lapack/getrf.hpp | 1 + .../backend/extensions/lapack/getrf_batch.cpp | 32 ++++++++++++------- dpnp/backend/extensions/lapack/lapack_py.cpp | 6 ++-- dpnp/linalg/dpnp_utils_linalg.py | 2 ++ 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrf.hpp b/dpnp/backend/extensions/lapack/getrf.hpp index 5fd9ecdcc499..952b244ef132 100644 --- a/dpnp/backend/extensions/lapack/getrf.hpp +++ b/dpnp/backend/extensions/lapack/getrf.hpp @@ -44,6 +44,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, py::list dev_info, + std::int64_t m, std::int64_t n, std::int64_t stride_a, std::int64_t stride_ipiv, diff --git a/dpnp/backend/extensions/lapack/getrf_batch.cpp b/dpnp/backend/extensions/lapack/getrf_batch.cpp index ec87c8b1f2ae..446f565d6e49 100644 --- a/dpnp/backend/extensions/lapack/getrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrf_batch.cpp @@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*getrf_batch_impl_fn_ptr_t)( sycl::queue &, std::int64_t, + std::int64_t, char *, std::int64_t, std::int64_t, @@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t template static sycl::event getrf_batch_impl(sycl::queue &exec_q, + std::int64_t m, std::int64_t n, char *in_a, std::int64_t lda, @@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q, T *a = reinterpret_cast(in_a); const std::int64_t scratchpad_size = - mkl_lapack::getrf_batch_scratchpad_size(exec_q, n, n, lda, stride_a, + mkl_lapack::getrf_batch_scratchpad_size(exec_q, m, n, lda, stride_a, stride_ipiv, batch_size); T *scratchpad = nullptr; @@ -91,11 +93,11 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q, getrf_batch_event = mkl_lapack::getrf_batch( exec_q, - n, // The order of each square matrix in the batch; (0 ≤ n). + m, // The number of rows in each matrix in the batch; (0 ≤ m). // It must be a non-negative integer. n, // The number of columns in each matrix in the batch; (0 ≤ n). // It must be a non-negative integer. - a, // Pointer to the batch of square matrices, each of size (n x n). + a, // Pointer to the batch of input matrices, each of size (m x n). lda, // The leading dimension of each matrix in the batch. stride_a, // Stride between consecutive matrices in the batch. ipiv, // Pointer to the array of pivot indices for each matrix in @@ -179,6 +181,7 @@ std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, py::list dev_info, + std::int64_t m, std::int64_t n, std::int64_t stride_a, std::int64_t stride_ipiv, @@ -191,13 +194,13 @@ std::pair if (a_array_nd < 3) { throw py::value_error( "The input array has ndim=" + std::to_string(a_array_nd) + - ", but an array with ndim >= 3 is expected."); + ", but an array with ndim >= 3 is expected"); } if (ipiv_array_nd != 2) { throw py::value_error("The array of pivot indices has ndim=" + std::to_string(ipiv_array_nd) + - ", but a 2-dimensional array is expected."); + ", but a 2-dimensional array is expected"); } const int dev_info_size = py::len(dev_info); @@ -205,7 +208,7 @@ std::pair throw py::value_error("The size of 'dev_info' (" + std::to_string(dev_info_size) + ") does not match the expected batch size (" + - std::to_string(batch_size) + ")."); + std::to_string(batch_size) + ")"); } // check compatibility of execution queue and allocation queue @@ -241,7 +244,7 @@ std::pair if (getrf_batch_fn == nullptr) { throw py::value_error( "No getrf_batch implementation defined for the provided type " - "of the input matrix."); + "of the input matrix"); } auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); @@ -249,19 +252,26 @@ std::pair ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); if (ipiv_array_type_id != static_cast(dpctl_td_ns::typenum_t::INT64)) { - throw py::value_error("The type of 'ipiv_array' must be int64."); + throw py::value_error("The type of 'ipiv_array' must be int64"); + } + + const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw(); + if (ipiv_array_shape[0] != batch_size || + ipiv_array_shape[1] != std::min(m, n)) { + throw py::value_error( + "The shape of 'ipiv_array' must be (batch_size, min(m, n))"); } char *a_array_data = a_array.get_data(); - const std::int64_t lda = std::max(1UL, n); + const std::int64_t lda = std::max(1UL, m); char *ipiv_array_data = ipiv_array.get_data(); std::int64_t *d_ipiv = reinterpret_cast(ipiv_array_data); std::vector host_task_events; sycl::event getrf_batch_ev = getrf_batch_fn( - exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size, - dev_info, host_task_events, depends); + exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, + batch_size, dev_info, host_task_events, depends); sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {a_array, ipiv_array}, host_task_events); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index fb4dce4643b7..83a0555f808b 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -141,10 +141,10 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_getrf_batch", &lapack_ext::getrf_batch, "Call `getrf_batch` from OneMKL LAPACK library to return " - "the LU factorization of a batch of general n x n matrices", + "the LU factorization of a batch of general m x n matrices", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("dev_info_array"), py::arg("n"), py::arg("stride_a"), - py::arg("stride_ipiv"), py::arg("batch_size"), + py::arg("dev_info_array"), py::arg("m"), py::arg("n"), + py::arg("stride_a"), py::arg("stride_ipiv"), py::arg("batch_size"), py::arg("depends") = py::list()); m.def("_getri_batch", &lapack_ext::getri_batch, diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 838daac66303..3b57c1f12471 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -246,6 +246,7 @@ def _batched_inv(a, res_type): ipiv_h.get_array(), dev_info, n, + n, a_stride, ipiv_stride, batch_size, @@ -327,6 +328,7 @@ def _batched_lu_factor(a, res_type): ipiv_h.get_array(), dev_info_h, n, + n, a_stride, ipiv_stride, batch_size, From 7cb23dbcb95b76b4c5ff8940e3665ec9bfe528b8 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 21 Aug 2025 03:26:04 -0700 Subject: [PATCH 07/12] Extend dpnp.linalg.lu_factor to support batched inputs --- dpnp/linalg/dpnp_utils_linalg.py | 157 ++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 1 deletion(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 3b57c1f12471..92e4f45124fd 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -398,6 +398,159 @@ def _batched_lu_factor(a, res_type): return (out_a, out_ipiv, out_dev_info) +def _batched_lu_factor_scipy(a, res_type): + """SciPy-compatible LU factorization for batched inputs.""" + + # TODO: Find out at which array sizes the best performance is obtained + # getrf_batch implementation shows slow results with large arrays on GPU. + # Use getrf_batch only on CPU. + # On GPU call getrf for each two-dimensional array by loop + use_batch = a.sycl_device.has_aspect_cpu + + a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type + _manager = dpu.SequentialOrderManager[a_sycl_queue] + + m, n = a.shape[-2:] + k = min(m, n) + orig_shape = a.shape + # get 3d input arrays by reshape + a = dpnp.reshape(a, (-1, m, n)) + batch_size = a.shape[0] + + if use_batch: + # Reorder the elements by moving the last two axes of `a` to the front + # to match fortran-like array order which is assumed by getrf_batch + a = dpnp.moveaxis(a, 0, -1) + + a_usm_arr = dpnp.get_usm_ndarray(a) + + # `a` must be copied because getrf_batch destroys the input matrix + a_h = dpnp.empty_like(a, order="F", dtype=res_type) + ipiv_h = dpnp.empty( + (batch_size, k), + dtype=dpnp.int64, + order="C", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + + dev_info_h = [0] * batch_size + + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) + + ipiv_stride = k + a_stride = a_h.strides[-1] + + # Call the LAPACK extension function _getrf_batch + # to perform LU decomposition of a batch of general matrices + ht_ev, getrf_ev = li._getrf_batch( + a_sycl_queue, + a_h.get_array(), + ipiv_h.get_array(), + dev_info_h, + m, + n, + a_stride, + ipiv_stride, + batch_size, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_ev) + + # getrf_batch expects `a` in Fortran order and overwrites it. + # Batch was moved to the last axis before the call. + # Move it back to the front and reshape to the original shape. + a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape) + ipiv_h = ipiv_h.reshape((*orig_shape[:-2], k)) + + if any(dev_info_h): + diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0) + warn( + f"Diagonal number {diag_nums} are exactly zero. " + "Singular matrix.", + RuntimeWarning, + stacklevel=2, + ) + + # MKL lapack uses 1-origin while SciPy uses 0-origin + ipiv_h -= 1 + + # Return a tuple containing the factorized matrix 'a_h', + # pivot indices 'ipiv_h' + return (a_h, ipiv_h) + + a_usm_arr = dpnp.get_usm_ndarray(a) + + # Initialize lists for storing arrays and events for each batch + a_vecs = [None] * batch_size + ipiv_vecs = [None] * batch_size + dev_info_vecs = [None] * batch_size + dep_evs = _manager.submitted_events + + # Process each batch + for i in range(batch_size): + # Copy each 2D slice to a new array because getrf will destroy + # the input matrix + a_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type) + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr[i], + dst=a_vecs[i].get_array(), + sycl_queue=a_sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, copy_ev) + + ipiv_vecs[i] = dpnp.empty( + (k,), + dtype=dpnp.int64, + order="C", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + + dev_info_vecs[i] = [0] + + # Call the LAPACK extension function _getrf + # to perform LU decomposition on each batch in 'a_vecs[i]' + ht_ev, getrf_ev = li._getrf( + a_sycl_queue, + a_vecs[i].get_array(), + ipiv_vecs[i].get_array(), + dev_info_vecs[i], + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_ev) + + # Reshape the results back to their original shape + out_a = dpnp.array(a_vecs).reshape(orig_shape) + out_ipiv = dpnp.array(ipiv_vecs).reshape((*orig_shape[:-2], k)) + + diag_nums = ", ".join( + str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0 + ) + + if diag_nums: + warn( + f"Diagonal number {diag_nums} are exactly zero. Singular matrix.", + RuntimeWarning, + stacklevel=2, + ) + + # MKL lapack uses 1-origin while SciPy uses 0-origin + out_ipiv -= 1 + + # Return a tuple containing the factorized matrix 'out_a', + # pivot indices 'out_ipiv' + return (out_a, out_ipiv) + + def _batched_solve(a, b, exec_q, res_usm_type, res_type): """ _batched_solve(a, b, exec_q, res_usm_type, res_type) @@ -2323,7 +2476,9 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): raise ValueError("array must not contain infs or NaNs") if a.ndim > 2: - raise NotImplementedError("Batched matrices are not supported") + # SciPy always copies each 2D slice, + # so `overwrite_a` is ignored here + return _batched_lu_factor_scipy(a, res_type) _manager = dpu.SequentialOrderManager[a_sycl_queue] a_usm_arr = dpnp.get_usm_ndarray(a) From bde2d4e440dca00f45edc6cbc47a71d07654a9f6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 21 Aug 2025 06:40:03 -0700 Subject: [PATCH 08/12] Handle empty inputs correctly for dpnp.linalg.lu_factor() --- dpnp/linalg/dpnp_utils_linalg.py | 35 ++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 92e4f45124fd..41952234e86d 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -398,7 +398,7 @@ def _batched_lu_factor(a, res_type): return (out_a, out_ipiv, out_dev_info) -def _batched_lu_factor_scipy(a, res_type): +def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals """SciPy-compatible LU factorization for batched inputs.""" # TODO: Find out at which array sizes the best performance is obtained @@ -414,6 +414,19 @@ def _batched_lu_factor_scipy(a, res_type): m, n = a.shape[-2:] k = min(m, n) orig_shape = a.shape + batch_shape = orig_shape[:-2] + + # accommodate empty arrays + if a.size == 0: + lu = dpnp.empty_like(a) + piv = dpnp.empty( + (*batch_shape, k), + dtype=dpnp.int64, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return lu, piv + # get 3d input arrays by reshape a = dpnp.reshape(a, (-1, m, n)) batch_size = a.shape[0] @@ -468,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type): # Batch was moved to the last axis before the call. # Move it back to the front and reshape to the original shape. a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape) - ipiv_h = ipiv_h.reshape((*orig_shape[:-2], k)) + ipiv_h = ipiv_h.reshape((*batch_shape, k)) if any(dev_info_h): diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0) @@ -530,7 +543,7 @@ def _batched_lu_factor_scipy(a, res_type): # Reshape the results back to their original shape out_a = dpnp.array(a_vecs).reshape(orig_shape) - out_ipiv = dpnp.array(ipiv_vecs).reshape((*orig_shape[:-2], k)) + out_ipiv = dpnp.array(ipiv_vecs).reshape((*batch_shape, k)) diag_nums = ", ".join( str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0 @@ -2463,14 +2476,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): a_sycl_queue = a.sycl_queue a_usm_type = a.usm_type - # accommodate empty arrays - if a.size == 0: - lu = dpnp.empty_like(a) - piv = dpnp.arange( - 0, dtype=dpnp.int64, usm_type=a_usm_type, sycl_queue=a_sycl_queue - ) - return lu, piv - if check_finite: if not dpnp.isfinite(a).all(): raise ValueError("array must not contain infs or NaNs") @@ -2480,6 +2485,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): # so `overwrite_a` is ignored here return _batched_lu_factor_scipy(a, res_type) + # accommodate empty arrays + if a.size == 0: + lu = dpnp.empty_like(a) + piv = dpnp.arange( + 0, dtype=dpnp.int64, usm_type=a_usm_type, sycl_queue=a_sycl_queue + ) + return lu, piv + _manager = dpu.SequentialOrderManager[a_sycl_queue] a_usm_arr = dpnp.get_usm_ndarray(a) From bc50cbb02762af97d1518c9b9ffc61d660ee33b9 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 21 Aug 2025 07:19:34 -0700 Subject: [PATCH 09/12] Add tests for batched lu_factor --- dpnp/tests/test_linalg.py | 113 ++++++++++++++++++++++++++++++++++ dpnp/tests/test_sycl_queue.py | 2 +- dpnp/tests/test_usm_type.py | 2 +- 3 files changed, 115 insertions(+), 2 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index dd5daa99b74d..ad7ab568026f 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2022,6 +2022,119 @@ def test_batched_not_supported(self): assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp) +class TestLuFactorBatched: + @staticmethod + def _apply_pivots_rows(A_dp, piv_dp): + m = A_dp.shape[0] + rows = dpnp.arange(m) + for i in range(int(piv_dp.shape[0])): + r = int(piv_dp[i].item()) + if i != r: + tmp = rows[i].copy() + rows[i] = rows[r] + rows[r] = tmp + return A_dp[rows] + + @staticmethod + def _split_lu(lu, m, n): + L = dpnp.tril(lu, k=-1) + dpnp.fill_diagonal(L, 1) + L = L[:, : min(m, n)] + U = dpnp.triu(lu)[: min(m, n), :] + return L, U + + @pytest.mark.parametrize( + "shape", + [(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)], + ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_factor_batched(self, shape, order, dtype): + a_np = generate_random_numpy_array(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + lu, piv = dpnp.linalg.lu_factor( + a_dp, check_finite=False, overwrite_a=False + ) + + assert lu.shape == a_dp.shape + m, n = shape[-2], shape[-1] + assert piv.shape == (*shape[:-2], min(m, n)) + assert piv.dtype == dpnp.int64 + + a_3d = a_dp.reshape((-1, m, n)) + lu_3d = lu.reshape((-1, m, n)) + piv_2d = piv.reshape((-1, min(m, n))) + for i in range(a_3d.shape[0]): + L, U = self._split_lu(lu_3d[i], m, n) + A_cast = a_3d[i].astype(L.dtype, copy=False) + PA = self._apply_pivots_rows(A_cast, piv_2d[i]) + assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_overwrite(self, dtype, order): + a_dp = dpnp.arange(2 * 2 * 3, dtype=dtype).reshape(3, 2, 2, order=order) + a_dp_orig = a_dp.copy() + lu, piv = dpnp.linalg.lu_factor( + a_dp, overwrite_a=True, check_finite=False + ) + + assert lu is not a_dp + assert_allclose(a_dp, a_dp_orig) + + m = n = 2 + lu_3d = lu.reshape((-1, m, n)) + a_3d = a_dp.reshape((-1, m, n)) + piv_2d = piv.reshape((-1, min(m, n))) + for i in range(a_3d.shape[0]): + L, U = self._split_lu(lu_3d[i], m, n) + A_cast = a_3d[i].astype(L.dtype, copy=False) + PA = self._apply_pivots_rows(A_cast, piv_2d[i]) + assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)] + ) + def test_empty_inputs(self, shape): + a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + + lu, piv = dpnp.linalg.lu_factor(a, check_finite=False) + assert lu.shape == shape + m, n = shape[-2:] + assert piv.shape == (*shape[:-2], min(m, n)) + + def test_strided(self): + a = ( + dpnp.arange(5 * 3 * 3, dtype=dpnp.default_float_type()).reshape( + 5, 3, 3, order="F" + ) + + 0.1 + ) + a_stride = a[::2] + lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False) + for i in range(a_stride.shape[0]): + L, U = self._split_lu(lu[i], 3, 3) + PA = self._apply_pivots_rows( + a_stride[i].astype(L.dtype, copy=False), piv[i] + ) + assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) + + def test_singular_matrix(self): + a = dpnp.zeros((3, 2, 2), dtype=dpnp.float64) + a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]]) + a[1] = dpnp.eye(2) + a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]]) + with pytest.warns(RuntimeWarning, match="Singular matrix"): + dpnp.linalg.lu_factor(a, check_finite=False) + + def test_check_finite_raises(self): + a = dpnp.ones((2, 3, 3), dtype=dpnp.float64, order="F") + a[1, 0, 0] = dpnp.nan + assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True) + + class TestMatrixPower: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 7c08263e672c..d3d8e19439e2 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1572,7 +1572,7 @@ def test_lstsq(self, m, n, nrhs, device): @pytest.mark.parametrize( "data", - [[[1.0, 2.0], [3.0, 5.0]], [[]]], + [[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]], ) def test_lu_factor(self, data, device): a = dpnp.array(data, device=device) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index 6f886f8ec3c7..34fd9bbc003a 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1451,7 +1451,7 @@ def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other): @pytest.mark.parametrize( "data", - [[[1.0, 2.0], [3.0, 5.0]], [[]]], + [[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]], ) def test_lu_factor(self, data, usm_type): a = dpnp.array(data, usm_type=usm_type) From 4b221fd19d68310bd8d2b93e3226e9b8ce9b8639 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 21 Aug 2025 07:20:52 -0700 Subject: [PATCH 10/12] Remove test_batched_not_supported --- dpnp/tests/test_linalg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index ad7ab568026f..0a7ada865e44 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2017,10 +2017,6 @@ def test_check_finite_raises(self, bad): ValueError, dpnp.linalg.lu_factor, a_dp, check_finite=True ) - def test_batched_not_supported(self): - a_dp = dpnp.ones((2, 2, 2)) - assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp) - class TestLuFactorBatched: @staticmethod From af74642126085f3ae505cf3eed9c73a63b4842f8 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 22 Aug 2025 06:40:18 -0700 Subject: [PATCH 11/12] Apply remarks --- dpnp/tests/test_linalg.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 0a7ada865e44..2fd001d4d042 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2022,13 +2022,17 @@ class TestLuFactorBatched: @staticmethod def _apply_pivots_rows(A_dp, piv_dp): m = A_dp.shape[0] - rows = dpnp.arange(m) - for i in range(int(piv_dp.shape[0])): - r = int(piv_dp[i].item()) + + if m == 0 or piv_dp.size == 0: + return A_dp + + rows = list(range(m)) + piv_np = dpnp.asnumpy(piv_dp) + for i, r in enumerate(piv_np): if i != r: - tmp = rows[i].copy() - rows[i] = rows[r] - rows[r] = tmp + rows[i], rows[r] = rows[r], rows[i] + + rows = dpnp.asarray(rows) return A_dp[rows] @staticmethod @@ -2118,7 +2122,7 @@ def test_strided(self): assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) def test_singular_matrix(self): - a = dpnp.zeros((3, 2, 2), dtype=dpnp.float64) + a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type()) a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]]) a[1] = dpnp.eye(2) a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]]) @@ -2126,7 +2130,7 @@ def test_singular_matrix(self): dpnp.linalg.lu_factor(a, check_finite=False) def test_check_finite_raises(self): - a = dpnp.ones((2, 3, 3), dtype=dpnp.float64, order="F") + a = dpnp.ones((2, 3, 3), dtype=dpnp.default_float_type(), order="F") a[1, 0, 0] = dpnp.nan assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True) From 46deac333d6c01473179be6026f19dd96abc6124 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 26 Aug 2025 04:43:27 -0700 Subject: [PATCH 12/12] Update _batched_lu_factor_scipy by using single allocation with batch-axis views --- dpnp/linalg/dpnp_utils_linalg.py | 153 +++++++++++-------------------- 1 file changed, 56 insertions(+), 97 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 41952234e86d..803f8e7326c8 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -402,9 +402,9 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals """SciPy-compatible LU factorization for batched inputs.""" # TODO: Find out at which array sizes the best performance is obtained - # getrf_batch implementation shows slow results with large arrays on GPU. + # getrf_batch can be slow on large GPU arrays. # Use getrf_batch only on CPU. - # On GPU call getrf for each two-dimensional array by loop + # On GPU fall back to calling getrf per 2D slice. use_batch = a.sycl_device.has_aspect_cpu a_sycl_queue = a.sycl_queue @@ -416,7 +416,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals orig_shape = a.shape batch_shape = orig_shape[:-2] - # accommodate empty arrays + # handle empty input if a.size == 0: lu = dpnp.empty_like(a) piv = dpnp.empty( @@ -431,32 +431,33 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals a = dpnp.reshape(a, (-1, m, n)) batch_size = a.shape[0] - if use_batch: - # Reorder the elements by moving the last two axes of `a` to the front - # to match fortran-like array order which is assumed by getrf_batch - a = dpnp.moveaxis(a, 0, -1) + # Move batch axis to the end (m, n, batch) in Fortran order: + # required by getrf_batch + # and ensures each a[..., i] is F-contiguous for getrf + a = dpnp.moveaxis(a, 0, -1) - a_usm_arr = dpnp.get_usm_ndarray(a) + a_usm_arr = dpnp.get_usm_ndarray(a) - # `a` must be copied because getrf_batch destroys the input matrix - a_h = dpnp.empty_like(a, order="F", dtype=res_type) - ipiv_h = dpnp.empty( - (batch_size, k), - dtype=dpnp.int64, - order="C", - usm_type=a_usm_type, - sycl_queue=a_sycl_queue, - ) + # `a` must be copied because getrf/getrf_batch destroys the input matrix + a_h = dpnp.empty_like(a, order="F", dtype=res_type) + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) - dev_info_h = [0] * batch_size + ipiv_h = dpnp.empty( + (batch_size, k), + dtype=dpnp.int64, + order="C", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_h.get_array(), - sycl_queue=a_sycl_queue, - depends=_manager.submitted_events, - ) - _manager.add_event_pair(ht_ev, copy_ev) + if use_batch: + dev_info_h = [0] * batch_size ipiv_stride = k a_stride = a_h.strides[-1] @@ -477,12 +478,6 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals ) _manager.add_event_pair(ht_ev, getrf_ev) - # getrf_batch expects `a` in Fortran order and overwrites it. - # Batch was moved to the last axis before the call. - # Move it back to the front and reshape to the original shape. - a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape) - ipiv_h = ipiv_h.reshape((*batch_shape, k)) - if any(dev_info_h): diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0) warn( @@ -491,77 +486,41 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals RuntimeWarning, stacklevel=2, ) + else: + dev_info_vecs = [[0] for _ in range(batch_size)] + + # Sequential LU factorization using getrf per slice + for i in range(batch_size): + ht_ev, getrf_ev = li._getrf( + a_sycl_queue, + a_h[..., i].get_array(), + ipiv_h[i].get_array(), + dev_info_vecs[i], + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_ev) - # MKL lapack uses 1-origin while SciPy uses 0-origin - ipiv_h -= 1 - - # Return a tuple containing the factorized matrix 'a_h', - # pivot indices 'ipiv_h' - return (a_h, ipiv_h) - - a_usm_arr = dpnp.get_usm_ndarray(a) - - # Initialize lists for storing arrays and events for each batch - a_vecs = [None] * batch_size - ipiv_vecs = [None] * batch_size - dev_info_vecs = [None] * batch_size - dep_evs = _manager.submitted_events - - # Process each batch - for i in range(batch_size): - # Copy each 2D slice to a new array because getrf will destroy - # the input matrix - a_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type) - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr[i], - dst=a_vecs[i].get_array(), - sycl_queue=a_sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, copy_ev) - - ipiv_vecs[i] = dpnp.empty( - (k,), - dtype=dpnp.int64, - order="C", - usm_type=a_usm_type, - sycl_queue=a_sycl_queue, + diag_nums = ", ".join( + str(v) for info in dev_info_vecs for v in info if v > 0 ) + if diag_nums: + warn( + f"Diagonal number {diag_nums} are exactly zero. " + "Singular matrix.", + RuntimeWarning, + stacklevel=2, + ) - dev_info_vecs[i] = [0] - - # Call the LAPACK extension function _getrf - # to perform LU decomposition on each batch in 'a_vecs[i]' - ht_ev, getrf_ev = li._getrf( - a_sycl_queue, - a_vecs[i].get_array(), - ipiv_vecs[i].get_array(), - dev_info_vecs[i], - depends=[copy_ev], - ) - _manager.add_event_pair(ht_ev, getrf_ev) - - # Reshape the results back to their original shape - out_a = dpnp.array(a_vecs).reshape(orig_shape) - out_ipiv = dpnp.array(ipiv_vecs).reshape((*batch_shape, k)) - - diag_nums = ", ".join( - str(v) for dev_info_h in dev_info_vecs for v in dev_info_h if v > 0 - ) - - if diag_nums: - warn( - f"Diagonal number {diag_nums} are exactly zero. Singular matrix.", - RuntimeWarning, - stacklevel=2, - ) + # Restore original shape: move batch axis back and reshape + a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape) + ipiv_h = ipiv_h.reshape((*batch_shape, k)) - # MKL lapack uses 1-origin while SciPy uses 0-origin - out_ipiv -= 1 + # oneMKL LAPACK uses 1-origin while SciPy uses 0-origin + ipiv_h -= 1 - # Return a tuple containing the factorized matrix 'out_a', - # pivot indices 'out_ipiv' - return (out_a, out_ipiv) + # Return a tuple containing the factorized matrix 'a_h', + # pivot indices 'ipiv_h' + return (a_h, ipiv_h) def _batched_solve(a, b, exec_q, res_usm_type, res_type):