Skip to content

Commit 932a23a

Browse files
Extend getrs with trans_code argument
1 parent 1b94a40 commit 932a23a

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

dpnp/backend/extensions/lapack/getrs.cpp

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,7 @@ std::pair<sycl::event, sycl::event>
166166
const dpctl::tensor::usm_ndarray &a_array,
167167
const dpctl::tensor::usm_ndarray &ipiv_array,
168168
const dpctl::tensor::usm_ndarray &b_array,
169+
const int trans_code,
169170
const std::vector<sycl::event> &depends)
170171
{
171172
const int a_array_nd = a_array.get_ndim();
@@ -264,11 +265,20 @@ std::pair<sycl::event, sycl::event>
264265
const std::int64_t lda = std::max<size_t>(1UL, n);
265266
const std::int64_t ldb = std::max<size_t>(1UL, n);
266267

267-
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
268-
// For F-contiguous we use transpose::N.
269-
oneapi::mkl::transpose trans = is_a_array_c_contig
270-
? oneapi::mkl::transpose::T
271-
: oneapi::mkl::transpose::N;
268+
oneapi::mkl::transpose trans;
269+
switch (trans_code) {
270+
case 0:
271+
trans = oneapi::mkl::transpose::N;
272+
break;
273+
case 1:
274+
trans = oneapi::mkl::transpose::T;
275+
break;
276+
case 2:
277+
trans = oneapi::mkl::transpose::C;
278+
break;
279+
default:
280+
throw py::value_error("`trans_code` must be 0 (N), 1 (T), or 2 (C)");
281+
}
272282

273283
char *a_array_data = a_array.get_data();
274284
char *b_array_data = b_array.get_data();

dpnp/backend/extensions/lapack/getrs.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ extern std::pair<sycl::event, sycl::event>
3737
const dpctl::tensor::usm_ndarray &a_array,
3838
const dpctl::tensor::usm_ndarray &ipiv_array,
3939
const dpctl::tensor::usm_ndarray &b_array,
40+
const int trans_code,
4041
const std::vector<sycl::event> &depends = {});
4142

4243
extern void init_getrs_dispatch_vector(void);

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,8 @@ PYBIND11_MODULE(_lapack_impl, m)
160160
"the solves of linear equations with an LU-factored "
161161
"square coefficient matrix, with multiple right-hand sides",
162162
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
163-
py::arg("b_array"), py::arg("depends") = py::list());
163+
py::arg("b_array"), py::arg("trans_code"),
164+
py::arg("depends") = py::list());
164165

165166
m.def("_orgqr_batch", &lapack_ext::orgqr_batch,
166167
"Call `_orgqr_batch` from OneMKL LAPACK library to return "

0 commit comments

Comments
 (0)