Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ATen/ops/_linalg_check_errors_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/zeros_like.h>

#include <comm/SYCLContext.h>
#include <comm/TensorInfo.h>
Expand All @@ -37,21 +38,37 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}

// Translates a oneMKL batched-LAPACK failure into per-matrix info codes and
// emits a warning for every failure instead of letting the exception escape.
//
// info_cpu: host-accessible array indexed by batch (matrix) index. Only the
//           entries of matrices reported as failed are overwritten; the
//           caller is expected to have initialised the rest.
// be:       the batch_error caught around the oneMKL call.
//
// Codes written: the LAPACK info value for a lapack::exception, -1 for a
// sycl::exception, and 1 when the batch error carries no per-matrix
// exceptions at all. Any other exception type stored in the batch error is
// rethrown to the caller.
void error_handle(
    int32_t* info_cpu,
    const oneapi::mkl::lapack::batch_error& be) {
  // Bind by reference: no need to copy the two lists just to read them.
  const auto& errs = be.exceptions();
  const auto& ids = be.ids();

  // No per-matrix exceptions available: report the batch-level message once
  // and flag every matrix listed as failed with a generic non-zero code.
  if (errs.empty()) {
    TORCH_WARN(
        "Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
    for (const auto id : ids) {
      TORCH_WARN("Error in matrix #", id);
      info_cpu[id] = 1;
    }
    return;
  }

  // ids and errs are parallel lists: ids[k] is the batch index of the k-th
  // failed matrix and errs[k] is its exception. Indexing errs by the batch
  // index itself (as before) reads out of bounds whenever a failed matrix's
  // index is >= the number of failures. Guard against a length mismatch by
  // walking only the common prefix.
  const size_t num_failures =
      errs.size() < ids.size() ? errs.size() : ids.size();
  for (size_t k = 0; k < num_failures; ++k) {
    const auto id = ids[k];
    try {
      std::rethrow_exception(errs[k]);
    } catch (const oneapi::mkl::lapack::exception& e) {
      TORCH_WARN(
          "Caught lapack exception:\nWhat: ",
          e.what(),
          "\nInfo: ",
          e.info(),
          "\nDetail: ",
          e.detail());
      info_cpu[id] = e.info();
    } catch (const sycl::exception& e) {
      TORCH_WARN("Caught SYCL exception:\nWhat: ", e.what(), "\nInfo: -1");
      info_cpu[id] = -1;
    }
  }
}
Expand Down Expand Up @@ -372,7 +389,7 @@ template <typename scalar_t>
static void apply_lu_xpu_(
const Tensor& self_,
Tensor& pivots_,
int32_t* infos_) {
int32_t* info_data) {
// do nothing if empty input.
if (self_.numel() == 0)
return;
Expand Down Expand Up @@ -403,7 +420,7 @@ static void apply_lu_xpu_(
(scalar_t*)(scratchpad_at.data_ptr()),
scratchpadsize);
} catch (const oneapi::mkl::lapack::batch_error& be) {
error_handle(infos_, be);
error_handle(info_data, be);
}
}

Expand Down Expand Up @@ -436,8 +453,8 @@ static void apply_lu_solve_xpu_(
int64_t* ipiv = pivots.data_ptr<int64_t>();
scalar_t* b = b_.data_ptr<scalar_t>();

std::vector<int32_t> infos(batch_size, 0);
int32_t* infos_ = infos.data();
std::vector<int32_t> info_vec(batch_size, 0);
int32_t* info_data = info_vec.data();

auto execute_mkl_getrs =
[&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) {
Expand Down Expand Up @@ -471,7 +488,7 @@ static void apply_lu_solve_xpu_(
scratchpad_at.data_ptr<scalar_t>(),
scratchpad_size);
} catch (oneapi::mkl::lapack::batch_error be) {
error_handle(infos_, be);
error_handle(info_data, be);
}
};

Expand Down Expand Up @@ -529,17 +546,18 @@ void lu_factor_mkl(
"linalg.lu_factor: LU without pivoting is not implemented on the XPU");

// handle the info
info.zero_();
int32_t* infos_data = info.data_ptr<int32_t>();
Tensor info_ = at::zeros_like(info, Device(at::kCPU));
int32_t* info_data = info_.data_ptr<int32_t>();

// oneMKL requires Long for pivots but PyTorch provides Int
Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_xpu", [&] {
apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
apply_lu_xpu_<scalar_t>(LU, pivots_, info_data);
});

// Copy to original pivots tensor
// Copy to original info and pivots tensor
info.copy_(info_);
pivots.copy_(pivots_);
}

Expand Down