//
// Copyright (c) 2025 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib> // std::abs for integral types

#include <c10/util/irange.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using SizesType = executorch::aten::SizesType;

// Transform normalized coordinates to pixel space
inline float unnormalize_coord(float coord, int64_t size, bool align_corners) {
  if (align_corners) {
    // -1 and 1 correspond to the centers of the first and last pixels
    return ((coord + 1.0f) / 2.0f) * (size - 1);
  } else {
    // -1 and 1 correspond to the boundary of the image
    return ((coord + 1.0f) * size - 1.0f) / 2.0f;
  }
}
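// Worked example (illustrative): for coord = 0.5 and size = 4,
//   align_corners = true:  ((0.5 + 1) / 2) * (4 - 1) = 2.25
//   align_corners = false: ((0.5 + 1) * 4 - 1) / 2   = 2.5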

// Compute source index and interpolation weight
inline void compute_source_index_and_weight(
    float coord, int64_t size, bool align_corners,
    int64_t& index, float& weight) {
  float real_coord = unnormalize_coord(coord, size, align_corners);
  index = static_cast<int64_t>(std::floor(real_coord));
  weight = real_coord - index;
}
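// Worked example (illustrative): real_coord = 2.25 gives index = 2, weight = 0.25;
// real_coord = -0.5 gives index = -1, weight = 0.5 (floor keeps the weight in [0, 1)).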

// Apply reflective padding to handle out-of-bounds coordinates
inline int64_t reflect_coord(int64_t coord, int64_t size) {
  if (size <= 1) return 0;

  int64_t double_size = 2 * size - 2;
  if (double_size <= 0) return 0;

  // Handle negative coordinates
  int64_t abs_coord = std::abs(coord);
  abs_coord = abs_coord % double_size;
  if (abs_coord >= size) {
    abs_coord = double_size - abs_coord;
  }

  return abs_coord;
}
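// Worked example (illustrative): for size = 4 the reflection period is 6, so
// coord = 5 maps to 6 - 5 = 1 and coord = -1 maps to |-1| = 1.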

// Get pixel value with proper boundary handling based on padding mode
template <typename T>
T get_pixel_value(
    const T* input_data,
    int64_t n, int64_t c, int64_t h, int64_t w,
    int64_t channels, int64_t height, int64_t width,
    int64_t padding_mode) {
  // Handle out-of-bounds coordinates
  if (h < 0 || h >= height || w < 0 || w >= width) {
    if (padding_mode == 0) { // Zeros
      return 0;
    } else if (padding_mode == 1) { // Border
      h = std::min(std::max(h, static_cast<int64_t>(0)), height - 1);
      w = std::min(std::max(w, static_cast<int64_t>(0)), width - 1);
    } else if (padding_mode == 2) { // Reflection
      h = reflect_coord(h, height);
      w = reflect_coord(w, width);
    }
  }

  // Compute the flat offset, assuming a contiguous NCHW layout
  const int64_t batch_stride = channels * height * width;
  const int64_t channel_stride = height * width;
  const int64_t height_stride = width;

  return input_data[n * batch_stride + c * channel_stride + h * height_stride + w];
}
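// Worked example (illustrative): with channels = 2, height = 3, width = 4, the
// element (n = 1, c = 1, h = 2, w = 3) maps to flat offset
//   1 * (2 * 3 * 4) + 1 * (3 * 4) + 2 * 4 + 3 = 47.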

// Process grid sampling with specified input, grid, and modes
template <typename T>
void grid_sampler_2d_impl(
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    Tensor& out) {
  const int64_t N = input.size(0);
  const int64_t C = input.size(1);
  const int64_t inp_H = input.size(2);
  const int64_t inp_W = input.size(3);
  const int64_t out_H = grid.size(1);
  const int64_t out_W = grid.size(2);

  const T* input_data = input.data_ptr<T>();
  T* out_data = out.data_ptr<T>();
  const float* grid_data = grid.data_ptr<float>();

  // Output tensor strides for indexing (contiguous NCHW)
  const int64_t out_batch_stride = C * out_H * out_W;
  const int64_t out_channel_stride = out_H * out_W;
  const int64_t out_height_stride = out_W;

  // Grid tensor strides (contiguous [N, out_H, out_W, 2])
  const int64_t grid_batch_stride = out_H * out_W * 2;
  const int64_t grid_height_stride = out_W * 2;
  const int64_t grid_width_stride = 2;
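  // Layout example (illustrative): for out_H = out_W = 16 the strides are
  // 512, 32, and 2, so the (x, y) pair for (n = 0, h = 2, w = 3) starts at
  // flat offset 2 * 32 + 3 * 2 = 70 (x at 70, y at 71).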

  // Process each output pixel
  for (const auto n : c10::irange(N)) {
    for (const auto c : c10::irange(C)) {
      for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
          // Get grid coordinates (x, y) with stride calculation
          const int64_t grid_offset =
              n * grid_batch_stride + h * grid_height_stride + w * grid_width_stride;
          const float x = grid_data[grid_offset];
          const float y = grid_data[grid_offset + 1];

          // Calculate output index
          const int64_t out_idx = n * out_batch_stride + c * out_channel_stride +
              h * out_height_stride + w;

          // Apply interpolation method
          if (interpolation_mode == 0) { // Bilinear
            // Calculate corner indices and weights
            int64_t ix_nw;
            float lambda_x;
            compute_source_index_and_weight(x, inp_W, align_corners, ix_nw, lambda_x);
            int64_t iy_nw;
            float lambda_y;
            compute_source_index_and_weight(y, inp_H, align_corners, iy_nw, lambda_y);

            // Calculate bilinear weights
            float w_nw = (1 - lambda_x) * (1 - lambda_y);
            float w_ne = lambda_x * (1 - lambda_y);
            float w_sw = (1 - lambda_x) * lambda_y;
            float w_se = lambda_x * lambda_y;
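            // Example (illustrative): lambda_x = 0.25, lambda_y = 0.5 gives
            // w_nw = 0.375, w_ne = 0.125, w_sw = 0.375, w_se = 0.125 (sums to 1).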

            // Get corner pixel values with boundary checking
            T nw = get_pixel_value(input_data, n, c, iy_nw, ix_nw, C, inp_H, inp_W, padding_mode);
            T ne = get_pixel_value(input_data, n, c, iy_nw, ix_nw + 1, C, inp_H, inp_W, padding_mode);
            T sw = get_pixel_value(input_data, n, c, iy_nw + 1, ix_nw, C, inp_H, inp_W, padding_mode);
            T se = get_pixel_value(input_data, n, c, iy_nw + 1, ix_nw + 1, C, inp_H, inp_W, padding_mode);

            // Perform bilinear interpolation (weighted sum)
            out_data[out_idx] = static_cast<T>(nw * w_nw + ne * w_ne + sw * w_sw + se * w_se);
          } else if (interpolation_mode == 1) { // Nearest
            // Convert to pixel space and round to nearest pixel
            float ix = unnormalize_coord(x, inp_W, align_corners);
            float iy = unnormalize_coord(y, inp_H, align_corners);

            int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
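            // Note: std::round rounds halfway cases away from zero.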

            // Get nearest pixel value
            out_data[out_idx] = get_pixel_value(
                input_data, n, c, iy_nearest, ix_nearest, C, inp_H, inp_W, padding_mode);
          } else if (interpolation_mode == 2) { // Bicubic (not implemented)
            out_data[out_idx] = 0;
          }
        }
      }
    }
  }
}
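// Shape example (illustrative): input [2, 3, 32, 32] sampled with grid
// [2, 64, 64, 2] produces out [2, 3, 64, 64].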

// Main grid_sampler_2d function that validates inputs and dispatches to implementation
Tensor& grid_sampler_2d_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    Tensor& out) {
  // Check for 4D input, grid, and output before reading their sizes
  ET_KERNEL_CHECK(ctx, (input.dim() == 4), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, (grid.dim() == 4), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, (out.dim() == 4), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, (grid.size(3) == 2), InvalidArgument, out);

  const int64_t N = input.size(0);
  const int64_t C = input.size(1);
  const int64_t out_H = grid.size(1);
  const int64_t out_W = grid.size(2);

  // Check that input and grid agree on the batch dimension
  ET_KERNEL_CHECK(ctx, (grid.size(0) == N), InvalidArgument, out);

  // Check that grid is float type and that input and output dtypes match
  ET_KERNEL_CHECK(ctx, (grid.scalar_type() == ScalarType::Float), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, (input.scalar_type() == out.scalar_type()), InvalidArgument, out);

  // Check interpolation mode is valid (0=bilinear, 1=nearest, 2=bicubic)
  ET_KERNEL_CHECK(ctx, (interpolation_mode >= 0 && interpolation_mode <= 2), InvalidArgument, out);

  // Check padding mode is valid (0=zeros, 1=border, 2=reflection)
  ET_KERNEL_CHECK(ctx, (padding_mode >= 0 && padding_mode <= 2), InvalidArgument, out);

  // Check that the output shape is [N, C, out_H, out_W]
  ET_KERNEL_CHECK(ctx, (out.size(0) == N && out.size(1) == C && out.size(2) == out_H && out.size(3) == out_W),
      InvalidArgument, out);

  // Dispatch based on input scalar type
  ET_SWITCH_REAL_TYPES(input.scalar_type(), ctx, "grid_sampler_2d.out", T, [&]() {
    grid_sampler_2d_impl<T>(input, grid, interpolation_mode, padding_mode, align_corners, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch