diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
index e9631cf9d1..34f670c182 100644
--- a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
+++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 
 #include
 #include
@@ -40,6 +41,16 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
 void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
   auto errs = be.exceptions();
   auto ids = be.ids();
+
+  if (!errs.size()) {
+    TORCH_WARN("Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
+    for (auto& i : ids) {
+      TORCH_WARN("Error in matrix #", i);
+      infos[i] = 1;
+    }
+    return;
+  }
+
   for (auto& i : ids) {
     try {
       std::rethrow_exception(errs[i]);
@@ -529,8 +540,8 @@ void lu_factor_mkl(
       "linalg.lu_factor: LU without pivoting is not implemented on the XPU");
 
   // handle the info
-  info.zero_();
-  int32_t* infos_data = info.data_ptr<int32_t>();
+  Tensor info_ = at::zeros_like(info, Device(at::kCPU));
+  int32_t* infos_data = info_.data_ptr<int32_t>();
 
   // oneMKL requires Long for pivots but PyTorch provides Int
   Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
@@ -539,7 +550,8 @@ void lu_factor_mkl(
     apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
   });
 
-  // Copy to original pivots tensor
+  // Copy to original info and pivots tensors
+  info.copy_(info_);
   pivots.copy_(pivots_);
 }
 
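Note on the error_handle change: oneapi::mkl::lapack::batch_error can arrive with an empty exceptions() vector while ids() still names the failing matrices, and the pre-existing loop would then index errs[i] into that empty vector. The new early-return branch instead warns once and records a failure for every reported matrix. Below is a standalone sketch of that branching logic; mock_batch_error is a hypothetical stand-in for the oneMKL type (same accessors, no oneMKL dependency) and this is not the PR's code.

#include <cstdint>
#include <cstdio>
#include <exception>
#include <vector>

// Hypothetical mock of oneapi::mkl::lapack::batch_error, for illustration only.
struct mock_batch_error {
  std::vector<std::exception_ptr> errs;  // per-matrix exceptions; may be empty
  std::vector<int64_t> failed_ids;       // indices of the failing matrices
  const std::vector<std::exception_ptr>& exceptions() const { return errs; }
  const std::vector<int64_t>& ids() const { return failed_ids; }
  const char* what() const { return "batch error"; }
  int64_t info() const { return -1; }
};

void handle(int32_t* infos, const mock_batch_error& be) {
  auto errs = be.exceptions();
  auto ids = be.ids();
  if (!errs.size()) {
    // No per-matrix exception objects: mark every reported matrix as failed
    // instead of dereferencing the empty vector below.
    std::fprintf(stderr, "what: %s, info: %lld\n", be.what(), (long long)be.info());
    for (auto& i : ids) {
      infos[i] = 1;
    }
    return;
  }
  for (auto& i : ids) {
    try {
      std::rethrow_exception(errs[i]);
    } catch (const std::exception&) {
      infos[i] = 1;  // record the failure for matrix i
    }
  }
}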
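Note on the lu_factor_mkl change: error_handle writes the per-matrix statuses through the raw infos pointer from host code, so that pointer must be host-accessible; with info allocated on the XPU device, the old info.data_ptr<int32_t>() handed it a device pointer. The fix stages the statuses in a CPU tensor and copies them back once at the end. A minimal sketch of that pattern, assuming only the ATen API; report_batch_status is a hypothetical stand-in for the path (oneMKL call plus error_handle) that fills the statuses:

#include <ATen/ATen.h>

// Hypothetical stand-in for the code path that fills one int32_t status
// per matrix through a raw host pointer.
static void report_batch_status(int32_t* infos, int64_t batch) {
  for (int64_t i = 0; i < batch; ++i) {
    infos[i] = 0;  // 0 means success, following the LAPACK info convention
  }
}

void fill_info(at::Tensor& info) {
  // Stage in a CPU tensor: its data_ptr() is safe to write from the host,
  // which a device-resident `info` tensor's pointer would not be.
  at::Tensor info_ = at::zeros_like(info, at::Device(at::kCPU));
  report_batch_status(info_.data_ptr<int32_t>(), info_.numel());
  // One copy at the end; this is a host-to-device transfer when `info`
  // lives on the XPU.
  info.copy_(info_);
}

Allocating the staging tensor with at::zeros_like also replaces the in-place info.zero_(), so the caller's tensor is touched exactly once, by the final copy_.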