Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions dpnp/backend/extensions/lapack/getrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
const std::int64_t,
char *,
std::int64_t,
Expand All @@ -56,6 +57,7 @@ static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event getrf_impl(sycl::queue &exec_q,
const std::int64_t m,
const std::int64_t n,
char *in_a,
std::int64_t lda,
Expand All @@ -82,11 +84,11 @@ static sycl::event getrf_impl(sycl::queue &exec_q,

getrf_event = mkl_lapack::getrf(
exec_q,
n, // The order of the square matrix A (0 ≤ n).
m, // The number of rows in the input matrix A (0 ≤ m).
// It must be a non-negative integer.
n, // The number of columns in the square matrix A (0 ≤ n).
n, // The number of columns in the input matrix A (0 ≤ n).
// It must be a non-negative integer.
a, // Pointer to the square matrix A (n x n).
a, // Pointer to the input matrix A (n x n).
lda, // The leading dimension of matrix A.
// It must be at least max(1, n).
ipiv, // Pointer to the output array of pivot indices.
Expand All @@ -99,7 +101,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,

if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
<< " had an illegal value";
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
Expand Down Expand Up @@ -168,13 +170,13 @@ std::pair<sycl::event, sycl::event>
if (a_array_nd != 2) {
throw py::value_error(
"The input array has ndim=" + std::to_string(a_array_nd) +
", but a 2-dimensional array is expected.");
", but a 2-dimensional array is expected");
}

if (ipiv_array_nd != 1) {
throw py::value_error("The array of pivot indices has ndim=" +
std::to_string(ipiv_array_nd) +
", but a 1-dimensional array is expected.");
", but a 1-dimensional array is expected");
}

// check compatibility of execution queue and allocation queue
Expand All @@ -190,10 +192,11 @@ std::pair<sycl::event, sycl::event>
}

bool is_a_array_c_contig = a_array.is_c_contiguous();
bool is_a_array_f_contig = a_array.is_f_contiguous();
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
if (!is_a_array_c_contig) {
if (!is_a_array_c_contig && !is_a_array_f_contig) {
throw py::value_error("The input array "
"must be C-contiguous");
"must be contiguous");
}
if (!is_ipiv_array_c_contig) {
throw py::value_error("The array of pivot indices "
Expand All @@ -208,27 +211,33 @@ std::pair<sycl::event, sycl::event>
if (getrf_fn == nullptr) {
throw py::value_error(
"No getrf implementation defined for the provided type "
"of the input matrix.");
"of the input matrix");
}

auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
int ipiv_array_type_id =
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());

if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
throw py::value_error("The type of 'ipiv_array' must be int64.");
throw py::value_error("The type of 'ipiv_array' must be int64");
}

const std::int64_t n = a_array.get_shape_raw()[0];
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
const std::int64_t m = a_array_shape[0];
const std::int64_t n = a_array_shape[1];
const std::int64_t lda = std::max<size_t>(1UL, m);

if (ipiv_array.get_size() != std::min(m, n)) {
throw py::value_error("The size of 'ipiv_array' must be min(m, n)");
}

char *a_array_data = a_array.get_data();
const std::int64_t lda = std::max<size_t>(1UL, n);

char *ipiv_array_data = ipiv_array.get_data();
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);

std::vector<sycl::event> host_task_events;
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv,
sycl::event getrf_ev = getrf_fn(exec_q, m, n, a_array_data, lda, d_ipiv,
dev_info, host_task_events, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
Expand Down
5 changes: 3 additions & 2 deletions dpnp/backend/extensions/lapack/getrf_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,11 @@ std::pair<sycl::event, sycl::event>
}

bool is_a_array_c_contig = a_array.is_c_contiguous();
bool is_a_array_f_contig = a_array.is_f_contiguous();
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
if (!is_a_array_c_contig) {
if (!is_a_array_c_contig && !is_a_array_f_contig) {
throw py::value_error("The input array "
"must be C-contiguous");
"must be must contiguous");
}
if (!is_ipiv_array_c_contig) {
throw py::value_error("The array of pivot indices "
Expand Down
64 changes: 64 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
dpnp_eigh,
dpnp_inv,
dpnp_lstsq,
dpnp_lu_factor,
dpnp_matrix_power,
dpnp_matrix_rank,
dpnp_multi_dot,
Expand All @@ -79,6 +80,7 @@
"eigvalsh",
"inv",
"lstsq",
"lu_factor",
"matmul",
"matrix_norm",
"matrix_power",
Expand Down Expand Up @@ -901,6 +903,68 @@ def lstsq(a, b, rcond=None):
return dpnp_lstsq(a, b, rcond=rcond)


def lu_factor(a, overwrite_a=False, check_finite=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation for the function is not rendering

"""
Compute the pivoted LU decomposition of a matrix.

The decomposition is::

A = P @ L @ U

where `P` is a permutation matrix, `L` is lower triangular with unit
diagonal elements, and `U` is upper triangular.

Parameters
----------
a : (M, N) {dpnp.ndarray, usm_ndarray}
Input array to decompose.
overwrite_a : {None, bool}, optional
Whether to overwrite data in `a` (may increase performance)
Default: ``False``.
Comment on lines +922 to +923
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        Whether to overwrite data in `a` (may increase performance).

        Default: ``False``.

check_finite : {None, bool}, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


        Default: ``True``.


Returns
-------
lu :(M, N) dpnp.ndarray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sems space is missing: lu : (M, N)

Matrix containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv (K, ): dpnp.ndarray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

piv : (K, ) dpnp.ndarray

Pivot indices representing the permutation matrix P:
row i of matrix was interchanged with row piv[i].
``K = min(M, N)``.

Warning
-------
This function synchronizes in order to validate array elements
when ``check_finite=True``.

Limitations
-----------
Only two-dimensional input matrices are supported.
Otherwise, the function raises ``NotImplementedError`` exception.

Examples
--------
>>> import dpnp as np
>>> a = np.array([[4., 3.], [6., 3.]])
>>> lu, piv = np.linalg.lu_factor(a)
>>> lu
array([[6. , 3. ],
[0.66666667, 1. ]])
>>> piv
array([1, 1])

"""

dpnp.check_supported_arrays_type(a)
assert_stacked_2d(a)

return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)


def matmul(x1, x2, /):
"""
Computes the matrix product.
Expand Down
Loading
Loading