Skip to content

Commit 07d455e

Browse files
Extend getrs with trans_code argument (#2563)
This PR suggests expanding `getrs` function by adding `trans` argument and handling its values to match the behavior of `oneapi::mkl::lapack::getrs`
1 parent 1acd558 commit 07d455e

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

dpnp/backend/extensions/lapack/getrs.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ std::pair<sycl::event, sycl::event>
166166
const dpctl::tensor::usm_ndarray &a_array,
167167
const dpctl::tensor::usm_ndarray &ipiv_array,
168168
const dpctl::tensor::usm_ndarray &b_array,
169+
oneapi::mkl::transpose trans,
169170
const std::vector<sycl::event> &depends)
170171
{
171172
const int a_array_nd = a_array.get_ndim();
@@ -264,12 +265,6 @@ std::pair<sycl::event, sycl::event>
264265
const std::int64_t lda = std::max<size_t>(1UL, n);
265266
const std::int64_t ldb = std::max<size_t>(1UL, n);
266267

267-
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
268-
// For F-contiguous we use transpose::N.
269-
oneapi::mkl::transpose trans = is_a_array_c_contig
270-
? oneapi::mkl::transpose::T
271-
: oneapi::mkl::transpose::N;
272-
273268
char *a_array_data = a_array.get_data();
274269
char *b_array_data = b_array.get_data();
275270
char *ipiv_array_data = ipiv_array.get_data();

dpnp/backend/extensions/lapack/getrs.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ extern std::pair<sycl::event, sycl::event>
3737
const dpctl::tensor::usm_ndarray &a_array,
3838
const dpctl::tensor::usm_ndarray &ipiv_array,
3939
const dpctl::tensor::usm_ndarray &b_array,
40+
oneapi::mkl::transpose trans,
4041
const std::vector<sycl::event> &depends = {});
4142

4243
extern void init_getrs_dispatch_vector(void);

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ void init_dispatch_tables(void)
7676

7777
PYBIND11_MODULE(_lapack_impl, m)
7878
{
79+
// Expose oneMKL transpose enum to Python
80+
py::enum_<oneapi::mkl::transpose>(m, "Transpose")
81+
.value("N", oneapi::mkl::transpose::N)
82+
.value("T", oneapi::mkl::transpose::T)
83+
.value("C", oneapi::mkl::transpose::C)
84+
.export_values(); // Optional, allows access like `Transpose.N`
85+
7986
// Register a custom LinAlgError exception in the dpnp.linalg submodule
8087
py::module_ linalg_module = py::module_::import("dpnp.linalg");
8188
py::register_exception<lapack_ext::LinAlgError>(
@@ -160,7 +167,8 @@ PYBIND11_MODULE(_lapack_impl, m)
160167
"the solves of linear equations with an LU-factored "
161168
"square coefficient matrix, with multiple right-hand sides",
162169
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
163-
py::arg("b_array"), py::arg("depends") = py::list());
170+
py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N,
171+
py::arg("depends") = py::list());
164172

165173
m.def("_orgqr_batch", &lapack_ext::orgqr_batch,
166174
"Call `_orgqr_batch` from OneMKL LAPACK library to return "

0 commit comments

Comments
 (0)