Commit b7ccec0

[WA] Fix the lack of exceptions vector in getrf_batch (#1916)
This PR provides a work-around for the lack of an exceptions vector within the batch_error exception; otherwise a segmentation fault would occur when accessing the exceptions.
1 parent: b979b55
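Background: error_handle walks be.ids() and rethrows be.exceptions()[i] for each failing matrix, but for getrf_batch oneMKL can raise a batch_error whose exceptions() vector is empty while ids() is still populated, so indexing the vector crashes. The work-around bails out of the per-matrix loop in that case, logging be.what()/be.info() once and flagging every listed matrix. A minimal, self-contained sketch of the hazard and the guard (hypothetical example, not the upstream code; errs, ids, and info stand in for be.exceptions(), be.ids(), and the info buffer):

#include <cstdint>
#include <exception>
#include <iostream>
#include <vector>

int main() {
  // Stand-ins for be.exceptions(), be.ids(), and the per-matrix info buffer.
  std::vector<std::exception_ptr> errs;  // may be empty in the failing case
  std::vector<int64_t> ids = {0, 2};     // ids can still name failed matrices
  std::vector<int32_t> info(4, 0);

  if (errs.empty()) {
    // Work-around: never index errs when it is empty; just flag the matrices.
    for (auto i : ids) {
      std::cerr << "Error in matrix #" << i << "\n";
      info[i] = 1;
    }
    return 0;
  }

  // Normal path: one exception per reported id.
  for (auto i : ids) {
    try {
      std::rethrow_exception(errs[i]);   // only safe when errs is populated
    } catch (const std::exception& e) {
      std::cerr << "Caught exception: " << e.what() << "\n";
      info[i] = -1;
    }
  }
  return 0;
}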

1 file changed: +35 / -17 lines


src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp

Lines changed: 35 additions & 17 deletions
@@ -11,6 +11,7 @@
 #include <ATen/ops/_linalg_check_errors_native.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/from_blob.h>
+#include <ATen/ops/zeros_like.h>
 
 #include <comm/SYCLContext.h>
 #include <comm/TensorInfo.h>
@@ -37,21 +38,37 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
 
-void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
+void error_handle(
+    int32_t* info_cpu,
+    const oneapi::mkl::lapack::batch_error& be) {
   auto errs = be.exceptions();
   auto ids = be.ids();
+
+  if (!errs.size()) {
+    TORCH_WARN(
+        "Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
+    for (auto& i : ids) {
+      TORCH_WARN("Error in matrix #", i);
+      info_cpu[i] = 1;
+    }
+    return;
+  }
+
   for (auto& i : ids) {
     try {
       std::rethrow_exception(errs[i]);
     } catch (const oneapi::mkl::lapack::exception& e) {
-      std::cout << "Cathed lapack exception:"
-                << "\nWhat: " << e.what() << "\nInfo: " << e.info()
-                << "\nDetail: " << e.detail() << std::endl;
-      infos[i] = e.info();
+      TORCH_WARN(
+          "Caught lapack exception:\nWhat: ",
+          e.what(),
+          "\nInfo: ",
+          e.info(),
+          "\nDetail: ",
+          e.detail());
+      info_cpu[i] = e.info();
     } catch (const sycl::exception& e) {
-      std::cout << "Catched SYCL exception:"
-                << "\nWhat: " << e.what() << "\nInfo: -1" << std::endl;
-      infos[i] = -1;
+      TORCH_WARN("Caught SYCL exception:\nWhat: ", e.what(), "\nInfo: -1");
+      info_cpu[i] = -1;
     }
   }
 }
@@ -372,7 +389,7 @@ template <typename scalar_t>
 static void apply_lu_xpu_(
     const Tensor& self_,
     Tensor& pivots_,
-    int32_t* infos_) {
+    int32_t* info_data) {
   // do nothing if empty input.
   if (self_.numel() == 0)
     return;
@@ -403,7 +420,7 @@ static void apply_lu_xpu_(
         (scalar_t*)(scratchpad_at.data_ptr()),
         scratchpadsize);
   } catch (const oneapi::mkl::lapack::batch_error& be) {
-    error_handle(infos_, be);
+    error_handle(info_data, be);
   }
 }
 
@@ -436,8 +453,8 @@ static void apply_lu_solve_xpu_(
   int64_t* ipiv = pivots.data_ptr<int64_t>();
   scalar_t* b = b_.data_ptr<scalar_t>();
 
-  std::vector<int32_t> infos(batch_size, 0);
-  int32_t* infos_ = infos.data();
+  std::vector<int32_t> info_vec(batch_size, 0);
+  int32_t* info_data = info_vec.data();
 
   auto execute_mkl_getrs =
       [&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) {
@@ -471,7 +488,7 @@ static void apply_lu_solve_xpu_(
             scratchpad_at.data_ptr<scalar_t>(),
             scratchpad_size);
       } catch (oneapi::mkl::lapack::batch_error be) {
-        error_handle(infos_, be);
+        error_handle(info_data, be);
       }
     };
 
@@ -529,17 +546,18 @@ void lu_factor_mkl(
       "linalg.lu_factor: LU without pivoting is not implemented on the XPU");
 
   // handle the info
-  info.zero_();
-  int32_t* infos_data = info.data_ptr<int32_t>();
+  Tensor info_ = at::zeros_like(info, Device(at::kCPU));
+  int32_t* info_data = info_.data_ptr<int32_t>();
 
   // oneMKL requires Long for pivots but PyTorch provides Int
   Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_xpu", [&] {
-    apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
+    apply_lu_xpu_<scalar_t>(LU, pivots_, info_data);
   });
 
-  // Copy to original pivots tensor
+  // Copy to original info and pivots tensor
+  info.copy_(info_);
   pivots.copy_(pivots_);
 }

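On the lu_factor_mkl side, the info codes are now staged in a CPU tensor created with at::zeros_like(info, Device(at::kCPU)), so error_handle writes into plain host int32_t storage, and the staged values are copied back into the caller's info tensor (which may live on the XPU) only after the dispatch finishes. A rough sketch of that staging pattern, assuming an int32 info tensor and an ATen build; the factorization call itself is elided:

#include <ATen/ATen.h>
#include <cstdint>

// Sketch of the host-side staging pattern used by the patch; not the exact
// upstream code. `info` is expected to be an int32 tensor, as in linalg.
void stage_info_on_cpu(at::Tensor& info) {
  // Allocate the info buffer on the CPU so the error handler can write raw
  // int32_t values without touching device memory.
  at::Tensor info_cpu = at::zeros_like(info, at::Device(at::kCPU));
  int32_t* info_data = info_cpu.data_ptr<int32_t>();

  // ... hand info_data to the batched factorization and error_handle here ...
  (void)info_data;  // elided in this sketch

  // Copy the staged error codes into the caller's tensor, wherever it lives.
  info.copy_(info_cpu);
}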