Skip to content

Commit 7714cf8

Browse files
committed
Fix the lack of exceptions vector in getrf_batch
1 parent f2bcd8a commit 7714cf8

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <ATen/ops/_linalg_check_errors_native.h>
1212
#include <ATen/ops/empty.h>
1313
#include <ATen/ops/from_blob.h>
14+
#include <ATen/ops/zeros_like.h>
1415

1516
#include <comm/SYCLContext.h>
1617
#include <comm/TensorInfo.h>
@@ -40,6 +41,18 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
4041
void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
4142
auto errs = be.exceptions();
4243
auto ids = be.ids();
44+
45+
if (!errs.size()) {
46+
std::cout << "Cathed lapack exception:"
47+
<< "\nWhat: " << be.what() << "\nInfo: " << be.info()
48+
<< std::endl;
49+
for (auto& i : ids) {
50+
std::cout << "Error in martix #" << i << std::endl;
51+
infos[i] = 1;
52+
}
53+
return;
54+
}
55+
4356
for (auto& i : ids) {
4457
try {
4558
std::rethrow_exception(errs[i]);
@@ -529,8 +542,8 @@ void lu_factor_mkl(
529542
"linalg.lu_factor: LU without pivoting is not implemented on the XPU");
530543

531544
// handle the info
532-
info.zero_();
533-
int32_t* infos_data = info.data_ptr<int32_t>();
545+
Tensor info_ = at::zeros_like(info, Device(at::kCPU));
546+
int32_t* infos_data = info_.data_ptr<int32_t>();
534547

535548
// oneMKL requires Long for pivots but PyTorch provides Int
536549
Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
@@ -539,7 +552,8 @@ void lu_factor_mkl(
539552
apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
540553
});
541554

542-
// Copy to original pivots tensor
555+
// Copy to original info and pivots tensor
556+
info.copy_(info_);
543557
pivots.copy_(pivots_);
544558
}
545559

0 commit comments

Comments
 (0)