Skip to content

Commit 4270823

Browse files
Add more validation checks in getrs_batch
1 parent 36d21df commit 4270823

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

dpnp/backend/extensions/lapack/getrs_batch.cpp

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ static sycl::event getrs_batch_impl(sycl::queue &exec_q,
128128
scratchpad, // Pointer to scratchpad memory to be used by MKL
129129
// routine for storing intermediate results.
130130
scratchpad_size, depends);
131+
} catch (mkl_lapack::batch_error const &be) {
132+
// Get the indices of matrices within the batch that encountered an
133+
// error
134+
auto error_matrices_ids = be.ids();
135+
136+
// OneMKL batched functions throw a single `batch_error`
137+
// instead of per-matrix exceptions or an info array.
138+
// This is interpreted as a computation_error (singular matrix),
139+
// consistent with non-batched LAPACK behavior.
140+
is_exception_caught = false;
141+
if (scratchpad != nullptr) {
142+
dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q);
143+
}
144+
throw LinAlgError("The solve could not be completed.");
131145
} catch (mkl_lapack::exception const &e) {
132146
is_exception_caught = true;
133147
info = e.info();
@@ -203,17 +217,25 @@ std::pair<sycl::event, sycl::event>
203217
"The LU-factorized array has ndim=" + std::to_string(a_array_nd) +
204218
", but an array with ndim >= 3 is expected");
205219
}
206-
if (b_array_nd < 3) {
220+
if (b_array_nd < 2) {
207221
throw py::value_error("The right-hand sides array has ndim=" +
208222
std::to_string(b_array_nd) +
209-
", but an array with ndim >= 3 is expected");
223+
", but an array with ndim >= 2 is expected");
210224
}
211225
if (ipiv_array_nd < 1) {
212226
throw py::value_error("The array of pivot indices has ndim=" +
213227
std::to_string(ipiv_array_nd) +
214228
", but an array with ndim >= 2 is expected");
215229
}
216230

231+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
232+
if (a_array_shape[0] != a_array_shape[1]) {
233+
throw py::value_error("Expected batch of square matrices , but got "
234+
"matrix shape (" +
235+
std::to_string(a_array_shape[0]) + ", " +
236+
std::to_string(a_array_shape[1]) + ") in batch");
237+
}
238+
217239
if (ipiv_array_nd != a_array_nd - 1) {
218240
throw py::value_error(
219241
"The array of pivot indices has ndim=" +
@@ -222,16 +244,6 @@ std::pair<sycl::event, sycl::event>
222244
" is expected to match LU batch dimensions");
223245
}
224246

225-
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
226-
227-
if (a_array_shape[a_array_nd - 1] != a_array_shape[a_array_nd - 2]) {
228-
throw py::value_error(
229-
"The last two dimensions of the LU array must be square,"
230-
" but got a shape of (" +
231-
std::to_string(a_array_shape[a_array_nd - 1]) + ", " +
232-
std::to_string(a_array_shape[a_array_nd - 2]) + ").");
233-
}
234-
235247
// check compatibility of execution queue and allocation queue
236248
if (!dpctl::utils::queues_are_compatible(exec_q,
237249
{a_array, b_array, ipiv_array}))
@@ -281,15 +293,15 @@ std::pair<sycl::event, sycl::event>
281293
if (getrs_batch_fn == nullptr) {
282294
throw py::value_error(
283295
"No getrs_batch implementation defined for the provided type "
284-
"of the input matrix.");
296+
"of the input matrix");
285297
}
286298

287299
auto ipiv_types = td_ns::usm_ndarray_types();
288300
int ipiv_array_type_id =
289301
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
290302

291303
if (ipiv_array_type_id != static_cast<int>(td_ns::typenum_t::INT64)) {
292-
throw py::value_error("The type of 'ipiv_array' must be int64.");
304+
throw py::value_error("The type of 'ipiv_array' must be int64");
293305
}
294306

295307
const std::int64_t lda = std::max<size_t>(1UL, n);

0 commit comments

Comments
 (0)