Skip to content

Commit 7e23c70

Browse files
John GibsonJohn Gibson
authored andcommitted
Working implementation through executorch test
1 parent 12d17ef commit 7e23c70

File tree

9 files changed

+1475
-0
lines changed

9 files changed

+1475
-0
lines changed

kernels/portable/cpu/op_grid_sampler_2d.cpp

Lines changed: 458 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/grid_sampler_2d_util.h>
10+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
15+
bool check_grid_sampler_2d_args(
16+
const Tensor& input,
17+
const Tensor& grid,
18+
const Tensor& out) {
19+
// Input must be 4D (N, C, H, W)
20+
ET_LOG_AND_RETURN_IF_FALSE(input.dim() == 4);
21+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(input));
22+
23+
// Grid must be 4D (N, H_out, W_out, 2)
24+
ET_LOG_AND_RETURN_IF_FALSE(grid.dim() == 4);
25+
ET_LOG_AND_RETURN_IF_FALSE(grid.size(3) == 2);
26+
27+
// Output must be 4D (N, C, H_out, W_out)
28+
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
29+
30+
// Batch sizes must match
31+
ET_LOG_AND_RETURN_IF_FALSE(input.size(0) == grid.size(0));
32+
ET_LOG_AND_RETURN_IF_FALSE(input.size(0) == out.size(0));
33+
34+
// Channel dimension must match between input and output
35+
ET_LOG_AND_RETURN_IF_FALSE(input.size(1) == out.size(1));
36+
37+
// Output spatial dimensions must match grid dimensions
38+
ET_LOG_AND_RETURN_IF_FALSE(out.size(2) == grid.size(1));
39+
ET_LOG_AND_RETURN_IF_FALSE(out.size(3) == grid.size(2));
40+
41+
// Input and output must have same dtype
42+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, out));
43+
44+
// Grid and input must have same dtype
45+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, grid));
46+
47+
// Output must have same dim order as input
48+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dim_order(input, out));
49+
50+
return true;
51+
}
52+
53+
Error check_grid_sampler_2d_args_and_resize_out(
54+
const Tensor& input,
55+
const Tensor& grid,
56+
Tensor& out) {
57+
// Input must be 4D (N, C, H, W)
58+
ET_CHECK_OR_RETURN_ERROR(
59+
input.dim() == 4,
60+
InvalidArgument,
61+
"Input must be 4D, got %zu dimensions",
62+
static_cast<size_t>(input.dim()));
63+
64+
ET_CHECK_OR_RETURN_ERROR(
65+
tensor_is_default_dim_order(input),
66+
InvalidArgument,
67+
"Input must be in NCHW format");
68+
69+
// Grid must be 4D (N, H_out, W_out, 2)
70+
ET_CHECK_OR_RETURN_ERROR(
71+
grid.dim() == 4,
72+
InvalidArgument,
73+
"Grid must be 4D, got %zu dimensions",
74+
static_cast<size_t>(grid.dim()));
75+
76+
ET_CHECK_OR_RETURN_ERROR(
77+
grid.size(3) == 2,
78+
InvalidArgument,
79+
"Grid last dimension must be 2, got %ld",
80+
static_cast<long>(grid.size(3)));
81+
82+
// Batch sizes must match
83+
ET_CHECK_OR_RETURN_ERROR(
84+
input.size(0) == grid.size(0),
85+
InvalidArgument,
86+
"Input and grid batch sizes must match, got input=%ld, grid=%ld",
87+
static_cast<long>(input.size(0)),
88+
static_cast<long>(grid.size(0)));
89+
90+
// Input and grid must have same dtype
91+
ET_CHECK_OR_RETURN_ERROR(
92+
tensors_have_same_dtype(input, grid),
93+
InvalidArgument,
94+
"Input and grid must have same dtype");
95+
96+
// Resize output tensor to [N, C, H_out, W_out]
97+
std::array<exec_aten::SizesType, 4> out_sizes = {
98+
static_cast<exec_aten::SizesType>(input.size(0)),
99+
static_cast<exec_aten::SizesType>(input.size(1)),
100+
static_cast<exec_aten::SizesType>(grid.size(1)),
101+
static_cast<exec_aten::SizesType>(grid.size(2))};
102+
103+
Error err = resize_tensor(out, {out_sizes.data(), 4});
104+
ET_CHECK_OR_RETURN_ERROR(
105+
err == Error::Ok,
106+
InvalidArgument,
107+
"Failed to resize output tensor");
108+
109+
return Error::Ok;
110+
}
111+
112+
} // namespace executor
113+
} // namespace torch
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
// Ported from aten/src/ATen/native/GridSampler.h
19+
// note that these need to be in the SAME ORDER as the enum in GridSampler.h
20+
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
21+
enum class GridSamplerPadding {Zeros, Border, Reflection};
22+
23+
24+
// Ported from aten/src/ATen/native/GridSampler.h
25+
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
26+
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
27+
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
28+
// -1 --> 0
29+
// +1 --> (size - 1)
30+
// scale_factor = (size - 1) / 2
31+
// if not align_corners: -1 and +1 get sent to the image edges
32+
// -1 --> -0.5
33+
// +1 --> (size - 1) + 0.5 == size - 0.5
34+
// scale_factor = size / 2
35+
template <typename scalar_t>
36+
inline scalar_t grid_sampler_unnormalize(
37+
scalar_t coord,
38+
int64_t size,
39+
bool align_corners) {
40+
if (align_corners) {
41+
// unnormalize coord from [-1, 1] to [0, size - 1]
42+
return ((coord + 1) / 2) * (size - 1);
43+
} else {
44+
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
45+
return ((coord + 1) * size - 1) / 2;
46+
}
47+
}
48+
49+
// Ported from aten/src/ATen/native/GridSampler.h
50+
// Clips coordinates to between 0 and clip_limit - 1
51+
template <typename scalar_t>
52+
inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
53+
return std::min(
54+
static_cast<scalar_t>(clip_limit - 1),
55+
std::max(in, static_cast<scalar_t>(0)));
56+
}
57+
58+
// Ported from aten/src/ATen/native/GridSampler.h
59+
// Reflects coordinates until they fall between low and high (inclusive).
60+
// The bounds are passed as twice their value so that half-integer values
61+
// can be represented as ints.
62+
template <typename scalar_t>
63+
inline scalar_t reflect_coordinates(
64+
scalar_t in,
65+
int64_t twice_low,
66+
int64_t twice_high) {
67+
if (twice_low == twice_high) {
68+
return static_cast<scalar_t>(0);
69+
}
70+
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
71+
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
72+
in = std::fabs(in - min);
73+
// `fmod` returns same sign as `in`, which is positive after the `fabs` above.
74+
scalar_t extra = std::fmod(in, span);
75+
int flips = static_cast<int>(std::floor(in / span));
76+
if (flips % 2 == 0) {
77+
return extra + min;
78+
} else {
79+
return span - extra + min;
80+
}
81+
}
82+
83+
// Ported from aten/src/ATen/native/GridSampler.h
84+
// Computes the pixel source index value for a grid coordinate
85+
template <typename scalar_t>
86+
inline scalar_t grid_sampler_compute_source_index(
87+
scalar_t coord,
88+
int64_t size,
89+
GridSamplerPadding padding_mode,
90+
bool align_corners) {
91+
coord = grid_sampler_unnormalize(coord, size, align_corners);
92+
if (padding_mode == GridSamplerPadding::Border) {
93+
// clip coordinates to image borders
94+
coord = clip_coordinates(coord, size);
95+
} else if (padding_mode == GridSamplerPadding::Reflection) {
96+
// reflect coordinates by image borders
97+
if (align_corners) {
98+
coord = reflect_coordinates(coord, 0, 2 * (size - 1));
99+
} else {
100+
coord = reflect_coordinates(coord, -1, 2 * size - 1);
101+
}
102+
coord = clip_coordinates(coord, size);
103+
}
104+
return coord;
105+
}
106+
107+
// Ported from aten/src/ATen/native/GridSampler.h
108+
// Check if coordinates are within bounds [0, limit-1]
109+
template <typename scalar_t>
110+
inline bool within_bounds_2d(scalar_t h, scalar_t w, int64_t H, int64_t W) {
111+
return h >= 0 && h < H && w >= 0 && w < W;
112+
}
113+
114+
// Ported from aten/src/ATen/native/UpSample.h
115+
// Cubic convolution function 1 (for points within 1 unit of the point)
116+
template <typename scalar_t>
117+
inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
118+
return ((A + 2) * x - (A + 3)) * x * x + 1;
119+
}
120+
121+
// Ported from aten/src/ATen/native/UpSample.h
122+
// Cubic convolution function 2 (for points between 1 and 2 units from the point)
123+
template <typename scalar_t>
124+
inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
125+
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
126+
}
127+
128+
// Ported from aten/src/ATen/native/UpSample.h
129+
// Computes the 4 cubic interpolation coefficients for a given position t in [0, 1]
130+
template <typename scalar_t>
131+
inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) {
132+
// Standard bicubic interpolation uses alpha = -0.75
133+
scalar_t A = static_cast<scalar_t>(-0.75);
134+
135+
scalar_t x1 = t;
136+
coeffs[0] = cubic_convolution2<scalar_t>(x1 + static_cast<scalar_t>(1.0), A);
137+
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
138+
139+
scalar_t x2 = static_cast<scalar_t>(1.0) - t;
140+
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
141+
coeffs[3] = cubic_convolution2<scalar_t>(x2 + static_cast<scalar_t>(1.0), A);
142+
}
143+
144+
// Ported from aten/src/ATen/native/UpSample.h
145+
// Performs 1D cubic interpolation given 4 points and a position t in [0, 1]
146+
template <typename scalar_t>
147+
inline scalar_t cubic_interp1d(
148+
scalar_t x0,
149+
scalar_t x1,
150+
scalar_t x2,
151+
scalar_t x3,
152+
scalar_t t) {
153+
scalar_t coeffs[4];
154+
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
155+
156+
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
157+
}
158+
159+
// Argument checking for grid_sampler_2d
160+
bool check_grid_sampler_2d_args(
161+
const Tensor& input,
162+
const Tensor& grid,
163+
const Tensor& out);
164+
165+
// Argument checking and output tensor resizing for grid_sampler_2d
166+
Error check_grid_sampler_2d_args_and_resize_out(
167+
const Tensor& input,
168+
const Tensor& grid,
169+
Tensor& out);
170+
171+
} // namespace executor
172+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,11 @@
427427
- arg_meta: null
428428
kernel_name: torch::executor::glu_out
429429

430+
- op: grid_sampler_2d.out
431+
kernels:
432+
- arg_meta: null
433+
kernel_name: torch::executor::grid_sampler_2d_out
434+
430435
- op: gt.Scalar_out
431436
kernels:
432437
- arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ runtime.cxx_library(
1919
],
2020
deps = [
2121
"//executorch/extension/aten_util:aten_bridge",
22+
"//executorch/kernels/portable/cpu:op_grid_sampler_2d",
2223
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
2324
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
2425
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",

0 commit comments

Comments
 (0)