Skip to content

Commit 35648db

Browse files
authored
Refine fft and lu_solve kernels on XPU device (#1757)
This PR refines the fft and lu_solve kernels: - Removes the unnecessary `event.wait_and_throw()` call in `_mkl_dft`, relying instead on `queue.throw_asynchronous()` for error propagation without blocking. - Adds error handling for batched LU solve operations.
1 parent 5888b49 commit 35648db

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,9 @@ static void apply_lu_xpu_(
409409

410410
template <typename scalar_t>
411411
static void apply_lu_solve_xpu_(
412-
const Tensor& b_,
413412
const Tensor& lu_,
414413
const Tensor& pivots_,
414+
const Tensor& b_,
415415
TransposeType t) {
416416
// do nothing if empty input
417417
if (lu_.numel() == 0)
@@ -436,6 +436,9 @@ static void apply_lu_solve_xpu_(
436436
int64_t* ipiv = pivots.data_ptr<int64_t>();
437437
scalar_t* b = b_.data_ptr<scalar_t>();
438438

439+
std::vector<int32_t> infos(batch_size, 0);
440+
int32_t* infos_ = infos.data();
441+
439442
auto execute_mkl_getrs =
440443
[&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) {
441444
int64_t scratchpad_size = mkl_getrs_scratchpad<scalar_t>(
@@ -450,22 +453,26 @@ static void apply_lu_solve_xpu_(
450453
stride_b,
451454
batch_size);
452455
Tensor scratchpad_at = at::empty({scratchpad_size}, b_.options());
453-
mkl_getrs<scalar_t>(
454-
queue,
455-
trans,
456-
n,
457-
nrhs,
458-
a,
459-
lda,
460-
stride_a,
461-
ipiv,
462-
stride_ipiv,
463-
b,
464-
ldb,
465-
stride_b,
466-
batch_size,
467-
scratchpad_at.data_ptr<scalar_t>(),
468-
scratchpad_size);
456+
try {
457+
mkl_getrs<scalar_t>(
458+
queue,
459+
trans,
460+
n,
461+
nrhs,
462+
a,
463+
lda,
464+
stride_a,
465+
ipiv,
466+
stride_ipiv,
467+
b,
468+
ldb,
469+
stride_b,
470+
batch_size,
471+
scratchpad_at.data_ptr<scalar_t>(),
472+
scratchpad_size);
473+
} catch (oneapi::mkl::lapack::batch_error be) {
474+
error_handle(infos_, be);
475+
}
469476
};
470477

471478
bool is_broadcast = false;
@@ -503,7 +510,7 @@ void lu_solve_mkl(
503510
const Tensor& B,
504511
TransposeType trans) {
505512
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_xpu", [&] {
506-
apply_lu_solve_xpu_<scalar_t>(B, LU, pivots, trans);
513+
apply_lu_solve_xpu_<scalar_t>(LU, pivots, B, trans);
507514
});
508515
}
509516

src/ATen/native/xpu/mkl/SpectralOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ void _mkl_dft(
118118
} else {
119119
event = compute_backward(desc, in_data, out_data);
120120
}
121-
event.wait_and_throw();
122121
queue.throw_asynchronous();
123122
}
124123

0 commit comments

Comments
 (0)