Skip to content

Commit bfb82b6

Browse files
Expose Transpose enum to Python via pybind11
1 parent ce878e6 commit bfb82b6

File tree

3 files changed

+10
-18
lines changed

3 files changed

+10
-18
lines changed

dpnp/backend/extensions/lapack/getrs.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ std::pair<sycl::event, sycl::event>
166166
const dpctl::tensor::usm_ndarray &a_array,
167167
const dpctl::tensor::usm_ndarray &ipiv_array,
168168
const dpctl::tensor::usm_ndarray &b_array,
169-
const int trans_code,
169+
oneapi::mkl::transpose trans,
170170
const std::vector<sycl::event> &depends)
171171
{
172172
const int a_array_nd = a_array.get_ndim();
@@ -265,21 +265,6 @@ std::pair<sycl::event, sycl::event>
265265
const std::int64_t lda = std::max<size_t>(1UL, n);
266266
const std::int64_t ldb = std::max<size_t>(1UL, n);
267267

268-
oneapi::mkl::transpose trans;
269-
switch (trans_code) {
270-
case 0:
271-
trans = oneapi::mkl::transpose::N;
272-
break;
273-
case 1:
274-
trans = oneapi::mkl::transpose::T;
275-
break;
276-
case 2:
277-
trans = oneapi::mkl::transpose::C;
278-
break;
279-
default:
280-
throw py::value_error("`trans_code` must be 0 (N), 1 (T), or 2 (C)");
281-
}
282-
283268
char *a_array_data = a_array.get_data();
284269
char *b_array_data = b_array.get_data();
285270
char *ipiv_array_data = ipiv_array.get_data();

dpnp/backend/extensions/lapack/getrs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ extern std::pair<sycl::event, sycl::event>
3737
const dpctl::tensor::usm_ndarray &a_array,
3838
const dpctl::tensor::usm_ndarray &ipiv_array,
3939
const dpctl::tensor::usm_ndarray &b_array,
40-
const int trans_code,
40+
oneapi::mkl::transpose trans,
4141
const std::vector<sycl::event> &depends = {});
4242

4343
extern void init_getrs_dispatch_vector(void);

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ void init_dispatch_tables(void)
7676

7777
PYBIND11_MODULE(_lapack_impl, m)
7878
{
79+
// Expose oneMKL transpose enum to Python
80+
py::enum_<oneapi::mkl::transpose>(m, "Transpose")
81+
.value("N", oneapi::mkl::transpose::N)
82+
.value("T", oneapi::mkl::transpose::T)
83+
.value("C", oneapi::mkl::transpose::C)
84+
.export_values(); // Optional, allows access like `Transpose.N`
85+
7986
// Register a custom LinAlgError exception in the dpnp.linalg submodule
8087
py::module_ linalg_module = py::module_::import("dpnp.linalg");
8188
py::register_exception<lapack_ext::LinAlgError>(
@@ -160,7 +167,7 @@ PYBIND11_MODULE(_lapack_impl, m)
160167
"the solves of linear equations with an LU-factored "
161168
"square coefficient matrix, with multiple right-hand sides",
162169
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
163-
py::arg("b_array"), py::arg("trans_code"),
170+
py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N,
164171
py::arg("depends") = py::list());
165172

166173
m.def("_orgqr_batch", &lapack_ext::orgqr_batch,

0 commit comments

Comments
 (0)