1+ /*
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3+ * All rights reserved.
4+ *
5+ * This source code is licensed under the BSD-style license found in the
6+ * LICENSE file in the root directory of this source tree.
7+ */
8+
9+ #pragma once
10+
11+ #include < executorch/runtime/core/exec_aten/exec_aten.h>
12+ #include < executorch/runtime/core/exec_aten/util/tensor_util.h>
13+ #include < executorch/runtime/kernel/kernel_includes.h>
14+
15+ namespace torch {
16+ namespace executor {
17+
18+ // Ported from aten/src/ATen/native/GridSampler.h
19+ // note that these need to be in the SAME ORDER as the enum in GridSampler.h
20+ enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
21+ enum class GridSamplerPadding {Zeros, Border, Reflection};
22+
23+
24+ // Ported from aten/src/ATen/native/GridSampler.h
25+ // Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
26+ // where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
27+ // if align_corners: -1 and +1 get sent to the centers of the corner pixels
28+ // -1 --> 0
29+ // +1 --> (size - 1)
30+ // scale_factor = (size - 1) / 2
31+ // if not align_corners: -1 and +1 get sent to the image edges
32+ // -1 --> -0.5
33+ // +1 --> (size - 1) + 0.5 == size - 0.5
34+ // scale_factor = size / 2
35+ template <typename scalar_t >
36+ inline scalar_t grid_sampler_unnormalize (
37+ scalar_t coord,
38+ int64_t size,
39+ bool align_corners) {
40+ if (align_corners) {
41+ // unnormalize coord from [-1, 1] to [0, size - 1]
42+ return ((coord + 1 ) / 2 ) * (size - 1 );
43+ } else {
44+ // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
45+ return ((coord + 1 ) * size - 1 ) / 2 ;
46+ }
47+ }
48+
49+ // Ported from aten/src/ATen/native/GridSampler.h
50+ // Clips coordinates to between 0 and clip_limit - 1
51+ template <typename scalar_t >
52+ inline scalar_t clip_coordinates (scalar_t in, int64_t clip_limit) {
53+ return std::min (
54+ static_cast <scalar_t >(clip_limit - 1 ),
55+ std::max (in, static_cast <scalar_t >(0 )));
56+ }
57+
58+ // Ported from aten/src/ATen/native/GridSampler.h
59+ // Reflects coordinates until they fall between low and high (inclusive).
60+ // The bounds are passed as twice their value so that half-integer values
61+ // can be represented as ints.
62+ template <typename scalar_t >
63+ inline scalar_t reflect_coordinates (
64+ scalar_t in,
65+ int64_t twice_low,
66+ int64_t twice_high) {
67+ if (twice_low == twice_high) {
68+ return static_cast <scalar_t >(0 );
69+ }
70+ scalar_t min = static_cast <scalar_t >(twice_low) / 2 ;
71+ scalar_t span = static_cast <scalar_t >(twice_high - twice_low) / 2 ;
72+ in = std::fabs (in - min);
73+ // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
74+ scalar_t extra = std::fmod (in, span);
75+ int flips = static_cast <int >(std::floor (in / span));
76+ if (flips % 2 == 0 ) {
77+ return extra + min;
78+ } else {
79+ return span - extra + min;
80+ }
81+ }
82+
83+ // Ported from aten/src/ATen/native/GridSampler.h
84+ // Computes the pixel source index value for a grid coordinate
85+ template <typename scalar_t >
86+ inline scalar_t grid_sampler_compute_source_index (
87+ scalar_t coord,
88+ int64_t size,
89+ GridSamplerPadding padding_mode,
90+ bool align_corners) {
91+ coord = grid_sampler_unnormalize (coord, size, align_corners);
92+ if (padding_mode == GridSamplerPadding::Border) {
93+ // clip coordinates to image borders
94+ coord = clip_coordinates (coord, size);
95+ } else if (padding_mode == GridSamplerPadding::Reflection) {
96+ // reflect coordinates by image borders
97+ if (align_corners) {
98+ coord = reflect_coordinates (coord, 0 , 2 * (size - 1 ));
99+ } else {
100+ coord = reflect_coordinates (coord, -1 , 2 * size - 1 );
101+ }
102+ coord = clip_coordinates (coord, size);
103+ }
104+ return coord;
105+ }
106+
107+ // Ported from aten/src/ATen/native/GridSampler.h
108+ // Check if coordinates are within bounds [0, limit-1]
109+ template <typename scalar_t >
110+ inline bool within_bounds_2d (scalar_t h, scalar_t w, int64_t H, int64_t W) {
111+ return h >= 0 && h < H && w >= 0 && w < W;
112+ }
113+
114+ // Ported from aten/src/ATen/native/UpSample.h
115+ // Cubic convolution function 1 (for points within 1 unit of the point)
116+ template <typename scalar_t >
117+ inline scalar_t cubic_convolution1 (scalar_t x, scalar_t A) {
118+ return ((A + 2 ) * x - (A + 3 )) * x * x + 1 ;
119+ }
120+
121+ // Ported from aten/src/ATen/native/UpSample.h
122+ // Cubic convolution function 2 (for points between 1 and 2 units from the point)
123+ template <typename scalar_t >
124+ inline scalar_t cubic_convolution2 (scalar_t x, scalar_t A) {
125+ return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
126+ }
127+
128+ // Ported from aten/src/ATen/native/UpSample.h
129+ // Computes the 4 cubic interpolation coefficients for a given position t in [0, 1]
130+ template <typename scalar_t >
131+ inline void get_cubic_upsample_coefficients (scalar_t coeffs[4 ], scalar_t t) {
132+ // Standard bicubic interpolation uses alpha = -0.75
133+ scalar_t A = static_cast <scalar_t >(-0.75 );
134+
135+ scalar_t x1 = t;
136+ coeffs[0 ] = cubic_convolution2<scalar_t >(x1 + static_cast <scalar_t >(1.0 ), A);
137+ coeffs[1 ] = cubic_convolution1<scalar_t >(x1, A);
138+
139+ scalar_t x2 = static_cast <scalar_t >(1.0 ) - t;
140+ coeffs[2 ] = cubic_convolution1<scalar_t >(x2, A);
141+ coeffs[3 ] = cubic_convolution2<scalar_t >(x2 + static_cast <scalar_t >(1.0 ), A);
142+ }
143+
144+ // Ported from aten/src/ATen/native/UpSample.h
145+ // Performs 1D cubic interpolation given 4 points and a position t in [0, 1]
146+ template <typename scalar_t >
147+ inline scalar_t cubic_interp1d (
148+ scalar_t x0,
149+ scalar_t x1,
150+ scalar_t x2,
151+ scalar_t x3,
152+ scalar_t t) {
153+ scalar_t coeffs[4 ];
154+ get_cubic_upsample_coefficients<scalar_t >(coeffs, t);
155+
156+ return x0 * coeffs[0 ] + x1 * coeffs[1 ] + x2 * coeffs[2 ] + x3 * coeffs[3 ];
157+ }
158+
159+ // Argument checking for grid_sampler_2d
160+ bool check_grid_sampler_2d_args (
161+ const Tensor& input,
162+ const Tensor& grid,
163+ const Tensor& out);
164+
165+ // Argument checking and output tensor resizing for grid_sampler_2d
166+ Error check_grid_sampler_2d_args_and_resize_out (
167+ const Tensor& input,
168+ const Tensor& grid,
169+ Tensor& out);
170+
171+ } // namespace executor
172+ } // namespace torch
0 commit comments