/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <pocketfft_hdronly.h>

#include <algorithm>
#include <array>
#include <cinttypes>
#include <cmath>
#include <complex>
#include <optional>

namespace torch::executor::native {

// TODO: contents of this anonymous namespace are copy/pasted from
// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small
// portions (the parts that don't depend on Tensor) could be reused;
// refactor to enable that once we can share headers from PyTorch
// core.
namespace {
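// pocketfft expects strides expressed in bytes, while ExecuTorch tensors
// report strides in elements, so scale each stride by the element size.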
pocketfft::stride_t stride_from_tensor(const Tensor& t) {
  pocketfft::stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

pocketfft::shape_t shape_from_tensor(const Tensor& t) {
  return pocketfft::shape_t(t.sizes().begin(), t.sizes().end());
}

// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what
// PyTorch core does and I'm not aware of a portable way to do this
// that doesn't rely on UB.
template <typename T>
inline std::complex<T>* tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(
      t.data_ptr<executorch::runtime::etensor::complex<T>>());
}

template <typename T>
inline const std::complex<T>* tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(
      t.const_data_ptr<executorch::runtime::etensor::complex<T>>());
}

// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and
// could be shared immediately.
enum class fft_norm_mode {
  none, // No normalization
  by_root_n, // Divide by sqrt(signal_size)
  by_n, // Divide by signal_size
};

// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK; the
// upstream version, which uses TORCH_CHECK, will be fine to use once we
// have code sharing.
template <typename T>
std::optional<T>
compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return one;
    case fft_norm_mode::by_n:
      return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n:
      return one / std::sqrt(static_cast<T>(size));
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      false,
      InvalidArgument,
      std::nullopt,
      "Unsupported normalization type: %" PRId64,
      normalization);
}

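// Overload that derives the signal size from the product of the sizes of
// the transformed dimensions, then delegates to the scalar overload above.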
template <typename T>
std::optional<T> compute_fct(
    KernelRuntimeContext& ctx,
    const Tensor& t,
    IntArrayRef dim,
    int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(ctx, n, normalization);
}

} // namespace

Tensor& opt_fft_r2c_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef dim,
    int64_t normalization,
    bool onesided,
    Tensor& out) {
  auto in_sizes = in.sizes();
  ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out);

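  // Build the output sizes in a fixed-size buffer; the dimension-limit
  // check above guarantees the input's sizes fit. Start from a copy of
  // the input sizes.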
  std::array<Tensor::SizesType, kTensorDimensionLimit> out_sizes_storage;
  executorch::runtime::Span<Tensor::SizesType> out_sizes(
      out_sizes_storage.data(), in_sizes.size());
  std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin());
  ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      onesided,
      InvalidArgument,
      out,
      "onesided=False is not supported yet in _fft_r2c");

  ET_KERNEL_CHECK_MSG(
      ctx,
      out.scalar_type() == executorch::runtime::toComplexType(in.scalar_type()),
      InvalidArgument,
      out,
      "the output type for _fft_r2c must be the Complex type corresponding to the input type");

  for (auto d : dim) {
    ET_KERNEL_CHECK_MSG(
        ctx,
        d >= 0 && d < in.dim(),
        InvalidArgument,
        out,
        "dims must be in bounds (got %" PRId64 ")",
        d);
  }

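  // A one-sided r2c output keeps only the non-redundant half of the
  // spectrum: floor(n/2) + 1 complex bins along the last transformed
  // dimension; the rest is determined by conjugate symmetry.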
  if (onesided) {
    out_sizes[dim.back()] = out_sizes[dim.back()] / 2 + 1;
  }
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out,
          executorch::runtime::ArrayRef<Tensor::SizesType>(
              out_sizes.data(), out_sizes.size())) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor (last dim %d).",
      out_sizes[dim.back()]);

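  // pocketfft takes the transform axes as a list of dimension indices.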
  pocketfft::shape_t axes(dim.begin(), dim.end());
  auto in_shape = shape_from_tensor(in);
  // TODO: if arbitrary strides are a possibility, we need to validate
  // these, because pocketfft README says "Strides that lead to
  // multiple accesses of the same memory address are not allowed."
  auto in_stride = stride_from_tensor(in);
  auto out_stride = stride_from_tensor(out);
  // NOTE: as of this writing, upstream PyTorch only supports
  // float/double, so we follow suit.
  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "_fft_r2c.out", CTYPE_IN, [&] {
    auto fct = compute_fct<CTYPE_IN>(ctx, in, dim, normalization);
    if (!fct) {
      // Check failed, just bail out of the lambda.
      return;
    }
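    // Forward real-to-complex transform over `axes`: `true` selects the
    // forward direction and `*fct` is the normalization factor applied to
    // the output.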
    pocketfft::r2c<CTYPE_IN>(
        in_shape,
        in_stride,
        out_stride,
        axes,
        true,
        in.const_data_ptr<CTYPE_IN>(),
        tensor_cdata<CTYPE_IN>(out),
        *fct);

    // TODO: fill with conjugate symmetry if not onesided; see
    // ATen/native/mkl/SpectralOps.cpp
  });
  return out;
}
} // namespace torch::executor::native