@@ -128,6 +128,20 @@ static sycl::event getrs_batch_impl(sycl::queue &exec_q,
128128 scratchpad, // Pointer to scratchpad memory to be used by MKL
129129 // routine for storing intermediate results.
130130 scratchpad_size, depends);
131+ } catch (mkl_lapack::batch_error const &be) {
132+ // Get the indices of matrices within the batch that encountered an
133+ // error
134+ auto error_matrices_ids = be.ids ();
135+
136+ // OneMKL batched functions throw a single `batch_error`
137+ // instead of per-matrix exceptions or an info array.
138+ // This is interpreted as a computation_error (singular matrix),
139+ // consistent with non-batched LAPACK behavior.
140+ is_exception_caught = false ;
141+ if (scratchpad != nullptr ) {
142+ dpctl::tensor::alloc_utils::sycl_free_noexcept (scratchpad, exec_q);
143+ }
144+ throw LinAlgError (" The solve could not be completed." );
131145 } catch (mkl_lapack::exception const &e) {
132146 is_exception_caught = true ;
133147 info = e.info ();
@@ -203,17 +217,25 @@ std::pair<sycl::event, sycl::event>
203217 " The LU-factorized array has ndim=" + std::to_string (a_array_nd) +
204218 " , but an array with ndim >= 3 is expected" );
205219 }
206- if (b_array_nd < 3 ) {
220+ if (b_array_nd < 2 ) {
207221 throw py::value_error (" The right-hand sides array has ndim=" +
208222 std::to_string (b_array_nd) +
209- " , but an array with ndim >= 3 is expected" );
223+ " , but an array with ndim >= 2 is expected" );
210224 }
211225 if (ipiv_array_nd < 1 ) {
212226 throw py::value_error (" The array of pivot indices has ndim=" +
213227 std::to_string (ipiv_array_nd) +
214228 " , but an array with ndim >= 2 is expected" );
215229 }
216230
231+ const py::ssize_t *a_array_shape = a_array.get_shape_raw ();
232+ if (a_array_shape[0 ] != a_array_shape[1 ]) {
233+ throw py::value_error (" Expected batch of square matrices , but got "
234+ " matrix shape (" +
235+ std::to_string (a_array_shape[0 ]) + " , " +
236+ std::to_string (a_array_shape[1 ]) + " ) in batch" );
237+ }
238+
217239 if (ipiv_array_nd != a_array_nd - 1 ) {
218240 throw py::value_error (
219241 " The array of pivot indices has ndim=" +
@@ -222,16 +244,6 @@ std::pair<sycl::event, sycl::event>
222244 " is expected to match LU batch dimensions" );
223245 }
224246
225- const py::ssize_t *a_array_shape = a_array.get_shape_raw ();
226-
227- if (a_array_shape[a_array_nd - 1 ] != a_array_shape[a_array_nd - 2 ]) {
228- throw py::value_error (
229- " The last two dimensions of the LU array must be square,"
230- " but got a shape of (" +
231- std::to_string (a_array_shape[a_array_nd - 1 ]) + " , " +
232- std::to_string (a_array_shape[a_array_nd - 2 ]) + " )." );
233- }
234-
235247 // check compatibility of execution queue and allocation queue
236248 if (!dpctl::utils::queues_are_compatible (exec_q,
237249 {a_array, b_array, ipiv_array}))
@@ -281,15 +293,15 @@ std::pair<sycl::event, sycl::event>
281293 if (getrs_batch_fn == nullptr ) {
282294 throw py::value_error (
283295 " No getrs_batch implementation defined for the provided type "
284- " of the input matrix. " );
296+ " of the input matrix" );
285297 }
286298
287299 auto ipiv_types = td_ns::usm_ndarray_types ();
288300 int ipiv_array_type_id =
289301 ipiv_types.typenum_to_lookup_id (ipiv_array.get_typenum ());
290302
291303 if (ipiv_array_type_id != static_cast <int >(td_ns::typenum_t ::INT64)) {
292- throw py::value_error (" The type of 'ipiv_array' must be int64. " );
304+ throw py::value_error (" The type of 'ipiv_array' must be int64" );
293305 }
294306
295307 const std::int64_t lda = std::max<size_t >(1UL , n);
0 commit comments