Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions dpnp/backend/extensions/lapack/getrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ std::pair<sycl::event, sycl::event>
const dpctl::tensor::usm_ndarray &a_array,
const dpctl::tensor::usm_ndarray &ipiv_array,
const dpctl::tensor::usm_ndarray &b_array,
oneapi::mkl::transpose trans,
const std::vector<sycl::event> &depends)
{
const int a_array_nd = a_array.get_ndim();
Expand Down Expand Up @@ -264,12 +265,6 @@ std::pair<sycl::event, sycl::event>
const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

// Use transpose::T if the LU-factorized array is passed as C-contiguous.
// For F-contiguous we use transpose::N.
oneapi::mkl::transpose trans = is_a_array_c_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;

char *a_array_data = a_array.get_data();
char *b_array_data = b_array.get_data();
char *ipiv_array_data = ipiv_array.get_data();
Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/getrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ extern std::pair<sycl::event, sycl::event>
const dpctl::tensor::usm_ndarray &a_array,
const dpctl::tensor::usm_ndarray &ipiv_array,
const dpctl::tensor::usm_ndarray &b_array,
oneapi::mkl::transpose trans,
const std::vector<sycl::event> &depends = {});

extern void init_getrs_dispatch_vector(void);
Expand Down
10 changes: 9 additions & 1 deletion dpnp/backend/extensions/lapack/lapack_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ void init_dispatch_tables(void)

PYBIND11_MODULE(_lapack_impl, m)
{
// Expose oneMKL transpose enum to Python
py::enum_<oneapi::mkl::transpose>(m, "Transpose")
.value("N", oneapi::mkl::transpose::N)
.value("T", oneapi::mkl::transpose::T)
.value("C", oneapi::mkl::transpose::C)
.export_values(); // Optional, allows access like `Transpose.N`

// Register a custom LinAlgError exception in the dpnp.linalg submodule
py::module_ linalg_module = py::module_::import("dpnp.linalg");
py::register_exception<lapack_ext::LinAlgError>(
Expand Down Expand Up @@ -160,7 +167,8 @@ PYBIND11_MODULE(_lapack_impl, m)
"the solves of linear equations with an LU-factored "
"square coefficient matrix, with multiple right-hand sides",
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
py::arg("b_array"), py::arg("depends") = py::list());
py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N,
py::arg("depends") = py::list());

m.def("_orgqr_batch", &lapack_ext::orgqr_batch,
"Call `_orgqr_batch` from OneMKL LAPACK library to return "
Expand Down
Loading