Skip to content

Commit aa6954b

Browse files
authored
Reduce fft_r2c overhead and reactivate fft_r2c on XPU (#1720)
# Motivation
Revert the `fft_r2c` fallback now that the `TestInductorOpInfoXPU` failure has been root-caused: there is a memory layout mismatch between `aten::fft_r2c` and Inductor's meta deduction. A separate PR will correct the Inductor meta deduction for the XPU backend.

# Solution
Inconsistent strides should be resolved in the `fft_r2c` meta calculation rather than in the `fft_r2c` XPU kernel. This PR includes the following changes:
- Reduce `_fft_r2c_mkl` overhead and simplify the kernel code
- Reactivate `fft_r2c` on XPU
1 parent 337deed commit aa6954b

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

src/ATen/native/xpu/SpectralOps.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
#include <ATen/native/Resize.h>
2-
#include <ATen/ops/_fft_r2c_native.h>
31
#if defined(USE_ONEMKL_XPU)
42
#include <ATen/native/xpu/mkl/SpectralOps.h>
53
#else
4+
#include <ATen/native/Resize.h>
65
#include <ATen/ops/_fft_c2c_native.h>
76
#include <ATen/ops/_fft_c2r_native.h>
7+
#include <ATen/ops/_fft_r2c_native.h>
88
#endif // USE_ONEMKL_XPU
99

1010
namespace at::native {
@@ -87,9 +87,13 @@ Tensor _fft_r2c_xpu(
8787
bool onesided) {
8888
TORCH_CHECK(self.is_floating_point());
8989

90+
#if defined(USE_ONEMKL_XPU)
91+
return native::xpu::_fft_r2c_mkl(self, dim, normalization, onesided);
92+
#else
9093
Tensor out_cpu = native::_fft_r2c_mkl(
9194
self.to(Device(at::kCPU)), dim, normalization, onesided);
9295
return out_cpu.to(Device(at::kXPU));
96+
#endif // USE_ONEMKL_XPU
9397
}
9498

9599
Tensor& _fft_r2c_xpu_out(
@@ -100,11 +104,15 @@ Tensor& _fft_r2c_xpu_out(
100104
Tensor& out) {
101105
TORCH_CHECK(self.is_floating_point());
102106

107+
#if defined(USE_ONEMKL_XPU)
108+
return native::xpu::_fft_r2c_mkl_out(self, dim, normalization, onesided, out);
109+
#else
103110
Tensor out_cpu = native::_fft_r2c_mkl(
104111
self.to(Device(at::kCPU)), dim, normalization, onesided);
105112
at::native::resize_output(out, out_cpu.sizes());
106113
out.copy_(out_cpu);
107114
return out;
115+
#endif // USE_ONEMKL_XPU
108116
}
109117

110118
} // namespace at::native

src/ATen/native/xpu/mkl/SpectralOps.cpp

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -499,11 +499,10 @@ Tensor _fft_r2c_mkl(
499499

500500
IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes;
501501

502-
auto sorted_dims = impl::_sort_dims(self, dim, /*exclude_last=*/true);
503502
auto out = at::empty(
504503
out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
505504

506-
auto working_tensor = self.clone(MemoryFormat::Contiguous);
505+
auto working_tensor = self.contiguous();
507506

508507
// First do the R2C transform on the last dimension
509508
impl::_exec_fft(
@@ -515,17 +514,12 @@ Tensor _fft_r2c_mkl(
515514
self.options().dtype(c10::toComplexType(self.scalar_type())));
516515
}
517516

518-
sorted_dims.resize(sorted_dims.size() - 1);
517+
DimVector sorted_dims(dim.begin(), dim.end() - 1);
519518

520519
while (!sorted_dims.empty()) {
521-
if (working_tensor.is_same(self)) {
522-
working_tensor = std::move(out);
523-
out = at::empty(
524-
out_sizes,
525-
self.options().dtype(c10::toComplexType(self.scalar_type())));
526-
} else {
527-
std::swap(out, working_tensor);
528-
}
520+
sorted_dims = impl::_sort_dims(self, sorted_dims);
521+
522+
std::swap(out, working_tensor);
529523

530524
const auto max_dims =
531525
std::min(static_cast<size_t>(impl::mkl_max_ndim), sorted_dims.size());
@@ -539,18 +533,13 @@ Tensor _fft_r2c_mkl(
539533
onesided,
540534
/*forward=*/true);
541535
sorted_dims.resize(sorted_dims.size() - max_dims);
542-
543-
if (sorted_dims.empty()) {
544-
break;
545-
}
546-
547-
sorted_dims = impl::_sort_dims(self, sorted_dims);
548536
}
549537

550538
// Only need to normalize the onesided slice since data in the other half is
551539
// overwritten
552540
auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
553-
working_tensor = self;
541+
impl::_fft_apply_normalization(out_slice, normalization, input_sizes, dim);
542+
554543
if (!onesided) {
555544
if (out.sizes()[last_dim] != out_sizes[last_dim]) {
556545
working_tensor.resize_(out_sizes, MemoryFormat::Contiguous);
@@ -560,7 +549,7 @@ Tensor _fft_r2c_mkl(
560549
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
561550
}
562551

563-
return impl::_fft_apply_normalization(out, normalization, input_sizes, dim);
552+
return out;
564553
}
565554

566555
Tensor& _fft_r2c_mkl_out(

0 commit comments

Comments
 (0)