diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index b7ac5311cb34..8185f3a06a79 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -166,6 +166,7 @@ std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, + oneapi::mkl::transpose trans, const std::vector &depends) { const int a_array_nd = a_array.get_ndim(); @@ -264,12 +265,6 @@ std::pair const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - // Use transpose::T if the LU-factorized array is passed as C-contiguous. - // For F-contiguous we use transpose::N. - oneapi::mkl::transpose trans = is_a_array_c_contig - ? oneapi::mkl::transpose::T - : oneapi::mkl::transpose::N; - char *a_array_data = a_array.get_data(); char *b_array_data = b_array.get_data(); char *ipiv_array_data = ipiv_array.get_data(); diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 8fa4889c99af..d8952f3f0b3f 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -37,6 +37,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, + oneapi::mkl::transpose trans, const std::vector &depends = {}); extern void init_getrs_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 83a0555f808b..46471cc2f366 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -76,6 +76,13 @@ void init_dispatch_tables(void) PYBIND11_MODULE(_lapack_impl, m) { + // Expose oneMKL transpose enum to Python + py::enum_(m, "Transpose") + .value("N", oneapi::mkl::transpose::N) + .value("T", oneapi::mkl::transpose::T) + .value("C", oneapi::mkl::transpose::C) + .export_values(); // Optional, allows access like `Transpose.N` + // Register a custom LinAlgError exception in the dpnp.linalg submodule py::module_ linalg_module = py::module_::import("dpnp.linalg"); py::register_exception( @@ -160,7 +167,8 @@ PYBIND11_MODULE(_lapack_impl, m) "the solves of linear equations with an LU-factored " "square coefficient matrix, with multiple right-hand sides", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("b_array"), py::arg("depends") = py::list()); + py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N, + py::arg("depends") = py::list()); m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return "