Skip to content

Commit c76aba5

Browse files
Extend getrf_batch to support non-square matrices
1 parent 25f7ab3 commit c76aba5

File tree

4 files changed

+27
-14
lines changed

4 files changed

+27
-14
lines changed

dpnp/backend/extensions/lapack/getrf.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ extern std::pair<sycl::event, sycl::event>
4444
const dpctl::tensor::usm_ndarray &a_array,
4545
const dpctl::tensor::usm_ndarray &ipiv_array,
4646
py::list dev_info,
47+
std::int64_t m,
4748
std::int64_t n,
4849
std::int64_t stride_a,
4950
std::int64_t stride_ipiv,

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 21 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils;
4646
typedef sycl::event (*getrf_batch_impl_fn_ptr_t)(
4747
sycl::queue &,
4848
std::int64_t,
49+
std::int64_t,
4950
char *,
5051
std::int64_t,
5152
std::int64_t,
@@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t
6162

6263
template <typename T>
6364
static sycl::event getrf_batch_impl(sycl::queue &exec_q,
65+
std::int64_t m,
6466
std::int64_t n,
6567
char *in_a,
6668
std::int64_t lda,
@@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
7779
T *a = reinterpret_cast<T *>(in_a);
7880

7981
const std::int64_t scratchpad_size =
80-
mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
82+
mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, m, n, lda, stride_a,
8183
stride_ipiv, batch_size);
8284
T *scratchpad = nullptr;
8385

@@ -91,11 +93,11 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
9193

9294
getrf_batch_event = mkl_lapack::getrf_batch(
9395
exec_q,
94-
n, // The order of each square matrix in the batch; (0 ≤ n).
96+
m, // The number of rows in each matrix in the batch; (0 ≤ m).
9597
// It must be a non-negative integer.
9698
n, // The number of columns in each matrix in the batch; (0 ≤ n).
9799
// It must be a non-negative integer.
98-
a, // Pointer to the batch of square matrices, each of size (n x n).
100+
a, // Pointer to the batch of input matrices, each of size (m x n).
99101
lda, // The leading dimension of each matrix in the batch.
100102
stride_a, // Stride between consecutive matrices in the batch.
101103
ipiv, // Pointer to the array of pivot indices for each matrix in
@@ -179,6 +181,7 @@ std::pair<sycl::event, sycl::event>
179181
const dpctl::tensor::usm_ndarray &a_array,
180182
const dpctl::tensor::usm_ndarray &ipiv_array,
181183
py::list dev_info,
184+
std::int64_t m,
182185
std::int64_t n,
183186
std::int64_t stride_a,
184187
std::int64_t stride_ipiv,
@@ -191,21 +194,21 @@ std::pair<sycl::event, sycl::event>
191194
if (a_array_nd < 3) {
192195
throw py::value_error(
193196
"The input array has ndim=" + std::to_string(a_array_nd) +
194-
", but an array with ndim >= 3 is expected.");
197+
", but an array with ndim >= 3 is expected");
195198
}
196199

197200
if (ipiv_array_nd != 2) {
198201
throw py::value_error("The array of pivot indices has ndim=" +
199202
std::to_string(ipiv_array_nd) +
200-
", but a 2-dimensional array is expected.");
203+
", but a 2-dimensional array is expected");
201204
}
202205

203206
const int dev_info_size = py::len(dev_info);
204207
if (dev_info_size != batch_size) {
205208
throw py::value_error("The size of 'dev_info' (" +
206209
std::to_string(dev_info_size) +
207210
") does not match the expected batch size (" +
208-
std::to_string(batch_size) + ").");
211+
std::to_string(batch_size) + ")");
209212
}
210213

211214
// check compatibility of execution queue and allocation queue
@@ -241,27 +244,34 @@ std::pair<sycl::event, sycl::event>
241244
if (getrf_batch_fn == nullptr) {
242245
throw py::value_error(
243246
"No getrf_batch implementation defined for the provided type "
244-
"of the input matrix.");
247+
"of the input matrix");
245248
}
246249

247250
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
248251
int ipiv_array_type_id =
249252
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
250253

251254
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
252-
throw py::value_error("The type of 'ipiv_array' must be int64.");
255+
throw py::value_error("The type of 'ipiv_array' must be int64");
256+
}
257+
258+
const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw();
259+
if (ipiv_array_shape[0] != batch_size ||
260+
ipiv_array_shape[1] != std::min(m, n)) {
261+
throw py::value_error(
262+
"The shape of 'ipiv_array' must be (batch_size, min(m, n))");
253263
}
254264

255265
char *a_array_data = a_array.get_data();
256-
const std::int64_t lda = std::max<size_t>(1UL, n);
266+
const std::int64_t lda = std::max<size_t>(1UL, m);
257267

258268
char *ipiv_array_data = ipiv_array.get_data();
259269
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
260270

261271
std::vector<sycl::event> host_task_events;
262272
sycl::event getrf_batch_ev = getrf_batch_fn(
263-
exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size,
264-
dev_info, host_task_events, depends);
273+
exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv,
274+
batch_size, dev_info, host_task_events, depends);
265275

266276
sycl::event args_ev = dpctl::utils::keep_args_alive(
267277
exec_q, {a_array, ipiv_array}, host_task_events);

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -141,10 +141,10 @@ PYBIND11_MODULE(_lapack_impl, m)
141141

142142
m.def("_getrf_batch", &lapack_ext::getrf_batch,
143143
"Call `getrf_batch` from OneMKL LAPACK library to return "
144-
"the LU factorization of a batch of general n x n matrices",
144+
"the LU factorization of a batch of general m x n matrices",
145145
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
146-
py::arg("dev_info_array"), py::arg("n"), py::arg("stride_a"),
147-
py::arg("stride_ipiv"), py::arg("batch_size"),
146+
py::arg("dev_info_array"), py::arg("m"), py::arg("n"),
147+
py::arg("stride_a"), py::arg("stride_ipiv"), py::arg("batch_size"),
148148
py::arg("depends") = py::list());
149149

150150
m.def("_getri_batch", &lapack_ext::getri_batch,

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -246,6 +246,7 @@ def _batched_inv(a, res_type):
246246
ipiv_h.get_array(),
247247
dev_info,
248248
n,
249+
n,
249250
a_stride,
250251
ipiv_stride,
251252
batch_size,
@@ -327,6 +328,7 @@ def _batched_lu_factor(a, res_type):
327328
ipiv_h.get_array(),
328329
dev_info_h,
329330
n,
331+
n,
330332
a_stride,
331333
ipiv_stride,
332334
batch_size,

0 commit comments

Comments
 (0)