Commit b00ce69

Update
[ghstack-poisoned]
1 parent 8ec08f9 commit b00ce69

11 files changed (+379 / -0 lines)

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -67,3 +67,6 @@
 [submodule "backends/cadence/utils/FACTO"]
 	path = backends/cadence/utils/FACTO
 	url = https://github.com/pytorch-labs/FACTO.git
+[submodule "third-party/pocketfft"]
+	path = third-party/pocketfft
+	url = https://github.com/mreineck/pocketfft

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@

 - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out

+- op: _fft_r2c.out
+
 - op: _linalg_det.result

 - op: _linalg_svd.U

kernels/optimized/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ message("Generated files ${gen_command_sources}")

 list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
 add_library(optimized_kernels ${_optimized_kernels__srcs})
+target_include_directories(optimized_kernels PRIVATE "${EXECUTORCH_ROOT}/third-party/pocketfft")
 target_link_libraries(
   optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
 )

kernels/optimized/cpu/op_fft_r2c.cpp (new file)

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <pocketfft_hdronly.h>

namespace torch::executor::native {

// TODO: contents of this anonymous namespace are copy/pasted from
// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
// portions (the parts that don't depend on Tensor) could be reused;
// refactor to enable that once we can share headers from PyTorch
// core.
namespace {
pocketfft::stride_t stride_from_tensor(const Tensor& t) {
  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

pocketfft::shape_t shape_from_tensor(const Tensor& t) {
  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
}

// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
// PyTorch core does and I'm not aware of a portable way to do this
// that doesn't rely on UB.
template <typename T>
inline std::complex<T>* tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(
      t.data_ptr<executorch::runtime::etensor::complex<T>>());
}

template <typename T>
inline const std::complex<T>* tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(
      t.const_data_ptr<executorch::runtime::etensor::complex<T>>());
}

// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
// could be shared immediately.
enum class fft_norm_mode {
  none, // No normalization
  by_root_n, // Divide by sqrt(signal_size)
  by_n, // Divide by signal_size
};

// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK;
// upstream with TORCH_CHECK will be fine to use once we have code
// sharing.
template <typename T>
std::optional<T>
compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return one;
    case fft_norm_mode::by_n:
      return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n:
      return one / std::sqrt(static_cast<T>(size));
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      false,
      InvalidArgument,
      std::nullopt,
      "Unsupported normalization type: %" PRId64,
      normalization);
}

template <typename T>
std::optional<T> compute_fct(
    KernelRuntimeContext& ctx,
    const Tensor& t,
    IntArrayRef dim,
    int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(ctx, n, normalization);
}

} // namespace

Tensor& opt_fft_r2c_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef dim,
    int64_t normalization,
    bool onesided,
    Tensor& out) {
  auto in_sizes = in.sizes();
  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);

  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage;
  executorch::runtime::Span<Tensor::SizesType> out_sizes(
      out_sizes_storage.data(), in_sizes.size());
  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      onesided,
      InvalidArgument,
      out,
      "onesided=False is not supported yet in _fft_r2c");

  ET_KERNEL_CHECK_MSG(
      ctx,
      out.scalar_type() == executorch::runtime::toComplexType(in.scalar_type()),
      InvalidArgument,
      out,
      "the output type for _fft_r2c must be the Complex type corresponding to the input type");

  for (auto d : dim) {
    ET_KERNEL_CHECK_MSG(
        ctx,
        d >= 0 && d < in.dim(),
        InvalidArgument,
        out,
        "dims must be in bounds (got %" PRId64 ")",
        d);
  }

  if (onesided) {
    out_sizes[dim.back()] = out_sizes[dim.back()] / 2 + 1;
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out,
          executorch::runtime::ArrayRef<Tensor::SizesType>(
              out_sizes.data(), out_sizes.size())) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor (last dim %d).",
      out_sizes[dim.back()]);

  pocketfft::shape_t axes(dim.begin(), dim.end());
  auto in_shape = shape_from_tensor(in);
  // TODO: if arbitrary strides are a possibility, we need to validate
  // these, because pocketfft README says "Strides that lead to
  // multiple accesses of the same memory address are not allowed."
  auto in_stride = stride_from_tensor(in);
  auto out_stride = stride_from_tensor(out);
  // NOTE: as of this writing, upstream PyTorch only supports
  // float/double, so we follow suit.
  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "_fft_r2c.out", CTYPE_IN, [&] {
    auto fct = compute_fct<CTYPE_IN>(ctx, in, dim, normalization);
    if (!fct) {
      // Check failed, just bail out of the lambda.
      return;
    }
    pocketfft::r2c<CTYPE_IN>(
        in_shape,
        in_stride,
        out_stride,
        axes,
        true,
        in.const_data_ptr<CTYPE_IN>(),
        tensor_cdata<CTYPE_IN>(out),
        *fct);

    // TODO: fill with conjugate symmetry if not onesided; see
    // ATen/native/mkl/SpectralOps.cpp
  });
  return out;
}
} // namespace torch::executor::native
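
For context, here is a minimal standalone sketch (not part of this commit) of the pocketfft header-only API that opt_fft_r2c_out calls into: strides are given in bytes, the axes vector names the transformed dimensions, and the scale factor fct plays the role of compute_fct's result. The buffer sizes and the constant input signal are illustrative assumptions, not taken from the kernel above.

#include <complex>
#include <cstdio>
#include <vector>

#include <pocketfft_hdronly.h>

int main() {
  const size_t n = 8;
  std::vector<float> in(n, 1.0f);                   // constant real signal
  std::vector<std::complex<float>> out(n / 2 + 1);  // onesided spectrum: n/2 + 1 bins

  pocketfft::shape_t shape{n};                                   // logical sizes
  pocketfft::stride_t in_stride{sizeof(float)};                  // strides in bytes
  pocketfft::stride_t out_stride{sizeof(std::complex<float>)};
  pocketfft::shape_t axes{0};                                    // transform over dim 0

  // fct = 1 corresponds to fft_norm_mode::none in the kernel above.
  pocketfft::r2c<float>(
      shape, in_stride, out_stride, axes, /*forward=*/true,
      in.data(), out.data(), 1.0f);

  // For a constant input all energy lands in bin 0: out[0] == (n, 0).
  std::printf("out[0] = (%g, %g)\n", out[0].real(), out[0].imag());
  return 0;
}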

kernels/optimized/cpu/targets.bzl

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,10 @@ _OPTIMIZED_ATEN_OPS = (
         ],
     ),
     op_target(name = "op_exp"),
+    op_target(
+        name = "op_fft_r2c",
+        deps = [] if runtime.is_oss else ["fbsource//third-party/pocketfft:pocketfft"],
+    ),
     op_target(name = "op_sigmoid"),
     op_target(
         name = "op_gelu",

kernels/optimized/optimized-oss.yaml

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@
 # log_softmax, due to the OSS build not currently including sleef.
 # TODO (T183193812)

+- op: _fft_r2c.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_r2c_out
+
 - op: add.out
   kernels:
     - arg_meta: null

kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,11 @@
 #
 # This yaml file contains operators that have optimized kernels available.

+- op: _fft_r2c.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_fft_r2c_out
+
 - op: _log_softmax.out
   kernels:
     - arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -265,6 +265,7 @@ set(_optimized_kernels_test_sources
    "op_bmm_test.cpp"
    "op_div_test.cpp"
    "op_exp_test.cpp"
+   "op_fft_r2c_test.cpp"
    "op_gelu_test.cpp"
    "op_le_test.cpp"
    "op_log_softmax_test.cpp"

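The contents of the new op_fft_r2c_test.cpp are not shown in this excerpt. Purely as a hypothetical illustration of the shape property such a test would presumably cover (with onesided=true the last transformed dimension shrinks to n/2 + 1, as in opt_fft_r2c_out above), here is a small standalone sketch; the helper name and sample sizes are invented for the example and do not come from the commit:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the commit): sizes of the complex output that
// opt_fft_r2c_out produces for a real input of the given sizes when
// onesided=true -- only the last transformed dim changes, to n/2 + 1.
std::vector<int64_t> expected_onesided_sizes(
    std::vector<int64_t> sizes,
    const std::vector<int64_t>& dim) {
  sizes[dim.back()] = sizes[dim.back()] / 2 + 1;
  return sizes;
}

int main() {
  // A [2, 8] real input transformed over dim 1 yields a [2, 5] complex output.
  assert((expected_onesided_sizes({2, 8}, {1}) == std::vector<int64_t>{2, 5}));
  // A length-7 input over dim 0 yields 7/2 + 1 = 4 output bins.
  assert((expected_onesided_sizes({7}, {0}) == std::vector<int64_t>{4}));
  return 0;
}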