Skip to content

Commit 36d21df

Browse files
Implement getrs_batch lapack extension
1 parent cb067e6 commit 36d21df

File tree

5 files changed

+392
-1
lines changed

5 files changed

+392
-1
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ set(_module_src
3737
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
3939
${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp
40+
${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp

dpnp/backend/extensions/lapack/getrs.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,20 @@ extern std::pair<sycl::event, sycl::event>
4040
oneapi::mkl::transpose trans,
4141
const std::vector<sycl::event> &depends = {});
4242

43+
// Solves a batch of systems of linear equations op(A_i) * X_i = B_i, where
// each A_i is an LU-factored square coefficient matrix with pivot indices
// in ipiv_array (strided layout described by n/nrhs/stride_* and
// batch_size). The solution overwrites b_array.
// Returns a pair of events: (argument keep-alive event, computation event).
extern std::pair<sycl::event, sycl::event>
getrs_batch(sycl::queue &exec_q,
            const dpctl::tensor::usm_ndarray &a_array,
            const dpctl::tensor::usm_ndarray &ipiv_array,
            const dpctl::tensor::usm_ndarray &b_array,
            oneapi::mkl::transpose trans,
            std::int64_t n,
            std::int64_t nrhs,
            std::int64_t stride_a,
            std::int64_t stride_ipiv,
            std::int64_t stride_b,
            std::int64_t batch_size,
            const std::vector<sycl::event> &depends = {});

extern void init_getrs_dispatch_vector(void);
// Populates the per-type dispatch table used by getrs_batch().
extern void init_getrs_batch_dispatch_vector(void);
4459
} // namespace dpnp::extensions::lapack
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"

#include "getrs.hpp"
#include "linalg_exceptions.hpp"
#include "types_matrix.hpp"
42+
43+
namespace dpnp::extensions::lapack
44+
{
45+
namespace mkl_lapack = oneapi::mkl::lapack;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
50+
typedef sycl::event (*getrs_batch_impl_fn_ptr_t)(
51+
sycl::queue &,
52+
oneapi::mkl::transpose, // trans
53+
const std::int64_t, // n
54+
const std::int64_t, // nrhs
55+
char *, // a
56+
std::int64_t, // lda
57+
std::int64_t, // stride_a
58+
std::int64_t *, // ipiv
59+
std::int64_t, // stride_ipiv
60+
char *, // b
61+
std::int64_t, // ldb
62+
std::int64_t, // stride_b
63+
std::int64_t, // batch_size
64+
std::vector<sycl::event> &,
65+
const std::vector<sycl::event> &);
66+
67+
static getrs_batch_impl_fn_ptr_t getrs_batch_dispatch_vector[td_ns::num_types];
68+
69+
template <typename T>
70+
static sycl::event getrs_batch_impl(sycl::queue &exec_q,
71+
oneapi::mkl::transpose trans,
72+
const std::int64_t n,
73+
const std::int64_t nrhs,
74+
char *in_a,
75+
std::int64_t lda,
76+
std::int64_t stride_a,
77+
std::int64_t *ipiv,
78+
std::int64_t stride_ipiv,
79+
char *in_b,
80+
std::int64_t ldb,
81+
std::int64_t stride_b,
82+
std::int64_t batch_size,
83+
std::vector<sycl::event> &host_task_events,
84+
const std::vector<sycl::event> &depends)
85+
{
86+
type_utils::validate_type_for_device<T>(exec_q);
87+
88+
T *a = reinterpret_cast<T *>(in_a);
89+
T *b = reinterpret_cast<T *>(in_b);
90+
91+
const std::int64_t scratchpad_size =
92+
mkl_lapack::getrs_batch_scratchpad_size<T>(exec_q, trans, n, nrhs, lda,
93+
stride_a, stride_ipiv, ldb,
94+
stride_b, batch_size);
95+
T *scratchpad = nullptr;
96+
97+
std::stringstream error_msg;
98+
std::int64_t info = 0;
99+
bool is_exception_caught = false;
100+
101+
sycl::event getrs_batch_event;
102+
try {
103+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
104+
105+
getrs_batch_event = mkl_lapack::getrs_batch(
106+
exec_q,
107+
trans, // Specifies the operation: whether or not to transpose
108+
// matrix A. Can be 'N' for no transpose, 'T' for transpose,
109+
// and 'C' for conjugate transpose.
110+
n, // The order of the square matrix A
111+
// and the number of rows in matrix B (0 ≤ n).
112+
// It must be a non-negative integer.
113+
nrhs, // The number of right-hand sides,
114+
// i.e., the number of columns in matrix B (0 ≤ nrhs).
115+
a, // Pointer to the square matrix A (n x n).
116+
lda, // The leading dimension of matrix A, must be at least max(1,
117+
// n). It must be at least max(1, n).
118+
stride_a, // Stride between consecutive A matrices in the batch.
119+
ipiv, // Pointer to the output array of pivot indices that were used
120+
// during factorization (n, ).
121+
stride_ipiv, // Stride between consecutive pivot arrays in the
122+
// batch.
123+
b, // Pointer to the matrix B of right-hand sides (ldb, nrhs).
124+
ldb, // The leading dimension of matrix B, must be at least max(1,
125+
// n).
126+
stride_b, // Stride between consecutive B matrices in the batch.
127+
batch_size, // Total number of matrices in the batch.
128+
scratchpad, // Pointer to scratchpad memory to be used by MKL
129+
// routine for storing intermediate results.
130+
scratchpad_size, depends);
131+
} catch (mkl_lapack::exception const &e) {
132+
is_exception_caught = true;
133+
info = e.info();
134+
135+
if (info < 0) {
136+
error_msg << "Parameter number " << -info
137+
<< " had an illegal value.";
138+
}
139+
else if (info == scratchpad_size && e.detail() != 0) {
140+
error_msg
141+
<< "Insufficient scratchpad size. Required size is at least "
142+
<< e.detail();
143+
}
144+
else if (info > 0) {
145+
is_exception_caught = false;
146+
if (scratchpad != nullptr) {
147+
dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad,
148+
exec_q);
149+
}
150+
throw LinAlgError("The solve could not be completed.");
151+
}
152+
else {
153+
error_msg << "Unexpected MKL exception caught during getrs() "
154+
"call:\nreason: "
155+
<< e.what() << "\ninfo: " << e.info();
156+
}
157+
} catch (sycl::exception const &e) {
158+
is_exception_caught = true;
159+
error_msg << "Unexpected SYCL exception caught during getrs() call:\n"
160+
<< e.what();
161+
}
162+
163+
if (is_exception_caught) // an unexpected error occurs
164+
{
165+
if (scratchpad != nullptr) {
166+
dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q);
167+
}
168+
169+
throw std::runtime_error(error_msg.str());
170+
}
171+
172+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
173+
cgh.depends_on(getrs_batch_event);
174+
auto ctx = exec_q.get_context();
175+
cgh.host_task([ctx, scratchpad]() {
176+
dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, ctx);
177+
});
178+
});
179+
host_task_events.push_back(clean_up_event);
180+
return getrs_batch_event;
181+
}
182+
183+
std::pair<sycl::event, sycl::event>
184+
getrs_batch(sycl::queue &exec_q,
185+
const dpctl::tensor::usm_ndarray &a_array,
186+
const dpctl::tensor::usm_ndarray &ipiv_array,
187+
const dpctl::tensor::usm_ndarray &b_array,
188+
oneapi::mkl::transpose trans,
189+
std::int64_t n,
190+
std::int64_t nrhs,
191+
std::int64_t stride_a,
192+
std::int64_t stride_ipiv,
193+
std::int64_t stride_b,
194+
std::int64_t batch_size,
195+
const std::vector<sycl::event> &depends)
196+
{
197+
const int a_array_nd = a_array.get_ndim();
198+
const int b_array_nd = b_array.get_ndim();
199+
const int ipiv_array_nd = ipiv_array.get_ndim();
200+
201+
if (a_array_nd < 3) {
202+
throw py::value_error(
203+
"The LU-factorized array has ndim=" + std::to_string(a_array_nd) +
204+
", but an array with ndim >= 3 is expected");
205+
}
206+
if (b_array_nd < 3) {
207+
throw py::value_error("The right-hand sides array has ndim=" +
208+
std::to_string(b_array_nd) +
209+
", but an array with ndim >= 3 is expected");
210+
}
211+
if (ipiv_array_nd < 1) {
212+
throw py::value_error("The array of pivot indices has ndim=" +
213+
std::to_string(ipiv_array_nd) +
214+
", but an array with ndim >= 2 is expected");
215+
}
216+
217+
if (ipiv_array_nd != a_array_nd - 1) {
218+
throw py::value_error(
219+
"The array of pivot indices has ndim=" +
220+
std::to_string(ipiv_array_nd) +
221+
", but an array with ndim=" + std::to_string(a_array_nd - 1) +
222+
" is expected to match LU batch dimensions");
223+
}
224+
225+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
226+
227+
if (a_array_shape[a_array_nd - 1] != a_array_shape[a_array_nd - 2]) {
228+
throw py::value_error(
229+
"The last two dimensions of the LU array must be square,"
230+
" but got a shape of (" +
231+
std::to_string(a_array_shape[a_array_nd - 1]) + ", " +
232+
std::to_string(a_array_shape[a_array_nd - 2]) + ").");
233+
}
234+
235+
// check compatibility of execution queue and allocation queue
236+
if (!dpctl::utils::queues_are_compatible(exec_q,
237+
{a_array, b_array, ipiv_array}))
238+
{
239+
throw py::value_error(
240+
"Execution queue is not compatible with allocation queues");
241+
}
242+
243+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
244+
if (overlap(a_array, b_array)) {
245+
throw py::value_error("The LU-factorized and right-hand sides arrays "
246+
"are overlapping segments of memory");
247+
}
248+
249+
bool is_a_array_c_contig = a_array.is_c_contiguous();
250+
bool is_a_array_f_contig = a_array.is_f_contiguous();
251+
bool is_b_array_f_contig = b_array.is_f_contiguous();
252+
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
253+
bool is_ipiv_array_f_contig = ipiv_array.is_f_contiguous();
254+
if (!is_a_array_c_contig && !is_a_array_f_contig) {
255+
throw py::value_error("The LU-factorized array "
256+
"must be either C-contiguous "
257+
"or F-contiguous");
258+
}
259+
if (!is_b_array_f_contig) {
260+
throw py::value_error("The right-hand sides array "
261+
"must be F-contiguous");
262+
}
263+
if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) {
264+
throw py::value_error("The array of pivot indices "
265+
"must be contiguous");
266+
}
267+
268+
auto array_types = td_ns::usm_ndarray_types();
269+
int a_array_type_id =
270+
array_types.typenum_to_lookup_id(a_array.get_typenum());
271+
int b_array_type_id =
272+
array_types.typenum_to_lookup_id(b_array.get_typenum());
273+
274+
if (a_array_type_id != b_array_type_id) {
275+
throw py::value_error("The types of the LU-factorized and "
276+
"right-hand sides arrays are mismatched");
277+
}
278+
279+
getrs_batch_impl_fn_ptr_t getrs_batch_fn =
280+
getrs_batch_dispatch_vector[a_array_type_id];
281+
if (getrs_batch_fn == nullptr) {
282+
throw py::value_error(
283+
"No getrs_batch implementation defined for the provided type "
284+
"of the input matrix.");
285+
}
286+
287+
auto ipiv_types = td_ns::usm_ndarray_types();
288+
int ipiv_array_type_id =
289+
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
290+
291+
if (ipiv_array_type_id != static_cast<int>(td_ns::typenum_t::INT64)) {
292+
throw py::value_error("The type of 'ipiv_array' must be int64.");
293+
}
294+
295+
const std::int64_t lda = std::max<size_t>(1UL, n);
296+
const std::int64_t ldb = std::max<size_t>(1UL, n);
297+
298+
char *a_array_data = a_array.get_data();
299+
char *b_array_data = b_array.get_data();
300+
char *ipiv_array_data = ipiv_array.get_data();
301+
302+
std::int64_t *ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
303+
304+
std::vector<sycl::event> host_task_events;
305+
sycl::event getrs_batch_ev = getrs_batch_fn(
306+
exec_q, trans, n, nrhs, a_array_data, lda, stride_a, ipiv, stride_ipiv,
307+
b_array_data, ldb, stride_b, batch_size, host_task_events, depends);
308+
309+
sycl::event args_ev = dpctl::utils::keep_args_alive(
310+
exec_q, {a_array, b_array, ipiv_array}, host_task_events);
311+
312+
return std::make_pair(args_ev, getrs_batch_ev);
313+
}
314+
315+
template <typename fnT, typename T>
316+
struct GetrsBatchContigFactory
317+
{
318+
fnT get()
319+
{
320+
if constexpr (types::GetrsBatchTypePairSupportFactory<T>::is_defined) {
321+
return getrs_batch_impl<T>;
322+
}
323+
else {
324+
return nullptr;
325+
}
326+
}
327+
};
328+
329+
void init_getrs_batch_dispatch_vector(void)
330+
{
331+
td_ns::DispatchVectorBuilder<getrs_batch_impl_fn_ptr_t,
332+
GetrsBatchContigFactory, td_ns::num_types>
333+
contig;
334+
contig.populate_dispatch_vector(getrs_batch_dispatch_vector);
335+
}
336+
} // namespace dpnp::extensions::lapack

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ void init_dispatch_vectors(void)
5858
lapack_ext::init_getrf_batch_dispatch_vector();
5959
lapack_ext::init_getrf_dispatch_vector();
6060
lapack_ext::init_getri_batch_dispatch_vector();
61+
lapack_ext::init_getrs_batch_dispatch_vector();
6162
lapack_ext::init_getrs_dispatch_vector();
6263
lapack_ext::init_orgqr_batch_dispatch_vector();
6364
lapack_ext::init_orgqr_dispatch_vector();
@@ -164,12 +165,22 @@ PYBIND11_MODULE(_lapack_impl, m)
164165

165166
m.def("_getrs", &lapack_ext::getrs,
166167
"Call `getrs` from OneMKL LAPACK library to return "
167-
"the solves of linear equations with an LU-factored "
168+
"the solutions of linear equations with an LU-factored "
168169
"square coefficient matrix, with multiple right-hand sides",
169170
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
170171
py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N,
171172
py::arg("depends") = py::list());
172173

174+
m.def("_getrs_batch", &lapack_ext::getrs_batch,
175+
"Call `getrs_batch` from OneMKL LAPACK library to return "
176+
"the solutions of batch linear equations with an LU-factored "
177+
"square coefficient matrix, with multiple right-hand sides",
178+
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
179+
py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N,
180+
py::arg("n"), py::arg("nrhs"), py::arg("stride_a"),
181+
py::arg("stride_ipiv"), py::arg("stride_b"), py::arg("batch_size"),
182+
py::arg("depends") = py::list());
183+
173184
m.def("_orgqr_batch", &lapack_ext::orgqr_batch,
174185
"Call `_orgqr_batch` from OneMKL LAPACK library to return "
175186
"the real orthogonal matrix Qi of the QR factorization "

0 commit comments

Comments
 (0)