
Commit 650fa2f

Giovanni Versiglioni authored and DenisVieriu97 committed
Add portable grid_sampler_2d implementation + tests
1 parent d2d2bf6 commit 650fa2f

File tree

6 files changed: +553 -4 lines
Lines changed: 226 additions & 0 deletions (new file)
//
// Copyright (c) 2025 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>

#include <cmath>
#include <cstdint>
#include <algorithm>

#include <c10/util/irange.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using SizesType = executorch::aten::SizesType;

// Transform normalized coordinates to pixel space
inline float unnormalize_coord(float coord, int64_t size, bool align_corners) {
  if (align_corners) {
    // -1 and 1 correspond to the centers of the first and last pixels
    return ((coord + 1.0f) / 2.0f) * (size - 1);
  } else {
    // -1 and 1 correspond to the boundary of the image
    return ((coord + 1.0f) * size - 1.0f) / 2.0f;
  }
}
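// Worked example (illustration, not part of the committed file): for
// size = 4, coord = -1 unnormalizes to 0.0 with align_corners=true but to
// -0.5 (the left image boundary) without it; coord = 1 maps to 3.0 versus
// 3.5 respectively.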

// Compute source index and interpolation weight
inline void compute_source_index_and_weight(
    float coord, int64_t size, bool align_corners,
    int64_t& index, float& weight) {
  float real_coord = unnormalize_coord(coord, size, align_corners);
  index = std::floor(real_coord);
  weight = real_coord - index;
}
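// Worked example (illustration, not part of the committed file): an
// unnormalized coordinate of 2.3 yields index = 2 and weight = 0.3, i.e.
// the sample point sits 30% of the way from pixel 2 toward pixel 3.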

// Apply reflective padding to handle out-of-bounds coordinates
inline int64_t reflect_coord(int64_t coord, int64_t size) {
  if (size <= 1) return 0;

  int64_t double_size = 2 * size - 2;
  if (double_size <= 0) return 0;

  // Handle negative coordinates
  int64_t abs_coord = std::abs(coord);
  abs_coord = abs_coord % double_size;
  if (abs_coord >= size) {
    abs_coord = double_size - abs_coord;
  }

  return abs_coord;
}
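// Worked example (illustration, not part of the committed file): for
// size = 5 the reflection period is 2*5 - 2 = 8, so coord 5 folds to 3,
// coord 7 to 1, and coord -1 to 1, keeping every index in [0, size - 1].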

// Get pixel value with proper boundary handling based on padding mode
template <typename T>
T get_pixel_value(
    const T* input_data,
    int64_t n, int64_t c, int64_t h, int64_t w,
    int64_t channels, int64_t height, int64_t width,
    int64_t padding_mode) {
  // Handle out-of-bounds coordinates
  if (h < 0 || h >= height || w < 0 || w >= width) {
    if (padding_mode == 0) { // Zeros
      return 0;
    } else if (padding_mode == 1) { // Border
      h = std::min(std::max(h, static_cast<int64_t>(0)), height - 1);
      w = std::min(std::max(w, static_cast<int64_t>(0)), width - 1);
    } else if (padding_mode == 2) { // Reflection
      h = reflect_coord(h, height);
      w = reflect_coord(w, width);
    }
  }

  // Calculate offset using strides for memory access; the batch stride is
  // the channel count times the spatial extent (not the channel index),
  // so indexing stays correct for batches beyond the first.
  const int64_t batch_stride = channels * height * width;
  const int64_t channel_stride = height * width;
  const int64_t height_stride = width;

  return input_data[n * batch_stride + c * channel_stride + h * height_stride + w];
}
// Process grid sampling with specified input, grid, and modes
template <typename T>
void grid_sampler_2d_impl(
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    Tensor& out) {
  const int64_t N = input.size(0);
  const int64_t C = input.size(1);
  const int64_t inp_H = input.size(2);
  const int64_t inp_W = input.size(3);
  const int64_t out_H = grid.size(1);
  const int64_t out_W = grid.size(2);

  const T* input_data = input.data_ptr<T>();
  T* out_data = out.data_ptr<T>();
  const float* grid_data = grid.data_ptr<float>();

  // Calculate output tensor strides for indexing
  const int64_t out_batch_stride = C * out_H * out_W;
  const int64_t out_channel_stride = out_H * out_W;
  const int64_t out_height_stride = out_W;

  // Calculate grid tensor strides, assuming a contiguous N x H x W x 2 layout
  const int64_t grid_batch_stride = out_H * out_W * 2;
  const int64_t grid_height_stride = out_W * 2;
  const int64_t grid_width_stride = 2;

  // Process each output pixel
  for (const auto n : c10::irange(N)) {
    for (const auto c : c10::irange(C)) {
      for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
          // Get grid coordinates (x, y) with stride calculation
          const int64_t grid_offset =
              n * grid_batch_stride + h * grid_height_stride + w * grid_width_stride;
          const float x = grid_data[grid_offset];
          const float y = grid_data[grid_offset + 1];

          // Calculate output index
          const int64_t out_idx = n * out_batch_stride + c * out_channel_stride +
              h * out_height_stride + w;

          // Apply interpolation method
          if (interpolation_mode == 0) { // Bilinear
            // Calculate northwest corner index and fractional weights
            int64_t ix_nw;
            float lambda_x;
            compute_source_index_and_weight(x, inp_W, align_corners, ix_nw, lambda_x);
            int64_t iy_nw;
            float lambda_y;
            compute_source_index_and_weight(y, inp_H, align_corners, iy_nw, lambda_y);

            // Calculate bilinear weights
            float w_nw = (1 - lambda_x) * (1 - lambda_y);
            float w_ne = lambda_x * (1 - lambda_y);
            float w_sw = (1 - lambda_x) * lambda_y;
            float w_se = lambda_x * lambda_y;

            // Get corner pixel values with boundary checking
            T nw = get_pixel_value(input_data, n, c, iy_nw, ix_nw, C, inp_H, inp_W, padding_mode);
            T ne = get_pixel_value(input_data, n, c, iy_nw, ix_nw + 1, C, inp_H, inp_W, padding_mode);
            T sw = get_pixel_value(input_data, n, c, iy_nw + 1, ix_nw, C, inp_H, inp_W, padding_mode);
            T se = get_pixel_value(input_data, n, c, iy_nw + 1, ix_nw + 1, C, inp_H, inp_W, padding_mode);

            // Perform bilinear interpolation (weighted sum)
            out_data[out_idx] = static_cast<T>(nw * w_nw + ne * w_ne + sw * w_sw + se * w_se);
          } else if (interpolation_mode == 1) { // Nearest
            // Convert to pixel space and round to the nearest pixel
            float ix = unnormalize_coord(x, inp_W, align_corners);
            float iy = unnormalize_coord(y, inp_H, align_corners);

            int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::round(iy));

            // Get nearest pixel value
            out_data[out_idx] = get_pixel_value(
                input_data, n, c, iy_nearest, ix_nearest, C, inp_H, inp_W, padding_mode);
          } else if (interpolation_mode == 2) { // Bicubic: not implemented, emits zeros
            out_data[out_idx] = 0;
          }
        }
      }
    }
  }
}

// Main grid_sampler_2d function that validates inputs and dispatches to implementation
Tensor& grid_sampler_2d_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    Tensor& out) {
  // Check for 4D input and grid before reading their sizes
  ET_KERNEL_CHECK(ctx, (input.dim() == 4), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, (grid.dim() == 4), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, (grid.size(3) == 2), InvalidArgument, out);

  const int64_t N = input.size(0);
  const int64_t C = input.size(1);
  const int64_t out_H = grid.size(1);
  const int64_t out_W = grid.size(2);

  // Check that grid is float type
  ET_KERNEL_CHECK(ctx, (grid.scalar_type() == ScalarType::Float), InvalidArgument, out);

  // Check interpolation mode is valid (0=bilinear, 1=nearest, 2=bicubic)
  ET_KERNEL_CHECK(ctx, (interpolation_mode >= 0 && interpolation_mode <= 2), InvalidArgument, out);

  // Check padding mode is valid (0=zeros, 1=border, 2=reflection)
  ET_KERNEL_CHECK(ctx, (padding_mode >= 0 && padding_mode <= 2), InvalidArgument, out);

  // Check for output shape
  ET_KERNEL_CHECK(
      ctx,
      (out.size(0) == N && out.size(1) == C && out.size(2) == out_H && out.size(3) == out_W),
      InvalidArgument, out);

  // Dispatch based on input scalar type
  ET_SWITCH_REAL_TYPES(input.scalar_type(), ctx, "grid_sampler_2d.out", T, [&]() {
    grid_sampler_2d_impl<T>(input, grid, interpolation_mode, padding_mode, align_corners, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
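
To make the interpolation math concrete, here is a minimal, self-contained sketch of the same bilinear sampling steps outside the kernel (the helper names `unnormalize` and `bilinear_sample` are ours, not part of the commit). With align_corners=true on a 2x2 image, the four grid corners unnormalize to exact pixel centers and the grid center averages all four pixels:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Same mapping as unnormalize_coord in the kernel above.
static float unnormalize(float coord, int64_t size, bool align_corners) {
  return align_corners ? ((coord + 1.0f) / 2.0f) * (size - 1)
                       : ((coord + 1.0f) * size - 1.0f) / 2.0f;
}

// Bilinear sample of a single-channel H x W image at normalized (x, y),
// clamping reads to the border (the kernel's padding_mode == 1).
static float bilinear_sample(const std::vector<float>& img, int64_t H,
                             int64_t W, float x, float y, bool align_corners) {
  const float fx = unnormalize(x, W, align_corners);
  const float fy = unnormalize(y, H, align_corners);
  const int64_t x0 = static_cast<int64_t>(std::floor(fx));
  const int64_t y0 = static_cast<int64_t>(std::floor(fy));
  const float lx = fx - x0; // fractional weights, as in
  const float ly = fy - y0; // compute_source_index_and_weight
  auto at = [&](int64_t h, int64_t w) {
    h = std::min(std::max<int64_t>(h, 0), H - 1);
    w = std::min(std::max<int64_t>(w, 0), W - 1);
    return img[h * W + w];
  };
  return at(y0, x0) * (1 - lx) * (1 - ly) + at(y0, x0 + 1) * lx * (1 - ly) +
      at(y0 + 1, x0) * (1 - lx) * ly + at(y0 + 1, x0 + 1) * lx * ly;
}

int main() {
  const int64_t H = 2, W = 2;
  const std::vector<float> img = {1.0f, 2.0f, 3.0f, 4.0f};
  std::printf("%.2f\n", bilinear_sample(img, H, W, -1.0f, -1.0f, true)); // 1.00
  std::printf("%.2f\n", bilinear_sample(img, H, W, 1.0f, 1.0f, true));   // 4.00
  std::printf("%.2f\n", bilinear_sample(img, H, W, 0.0f, 0.0f, true));   // 2.50
  return 0;
}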

kernels/portable/functions.yaml

Lines changed: 6 additions & 1 deletion
@@ -427,6 +427,11 @@
     - arg_meta: null
       kernel_name: torch::executor::glu_out
 
+- op: grid_sampler_2d.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::grid_sampler_2d_out
+
 - op: gt.Scalar_out
   kernels:
     - arg_meta: null
@@ -985,4 +990,4 @@
 - func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::_to_dim_order_copy_out
\ No newline at end of file
+      kernel_name: torch::executor::_to_dim_order_copy_out
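
For context, this `- op:` entry is what binds the new portable kernel to the operator schema; the signature itself lives in the ATen/ExecuTorch op registry rather than in this diff. Based on the upstream aten::grid_sampler_2d operator, the out-variant schema presumably reads:

grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)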

kernels/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -154,6 +154,7 @@ set(all_test_sources
     "op_ge_test.cpp"
     "op_gelu_test.cpp"
     "op_glu_test.cpp"
+    "op_grid_sampler_2d_test.cpp"
     "op_gt_test.cpp"
     "op_hardtanh_test.cpp"
     "op_index_put_test.cpp"
@@ -349,4 +350,4 @@ if(TARGET quantized_kernels)
     PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/include/quantized"
             "${CMAKE_CURRENT_BINARY_DIR}/include/portable"
   )
-endif()
\ No newline at end of file
+endif()
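
The registered op_grid_sampler_2d_test.cpp itself is not shown in this capture. A test along these lines would exercise the kernel; every name below (TensorFactory, EXPECT_TENSOR_CLOSE, the include paths) is our assumption about ExecuTorch's usual test utilities, not code from the commit:

#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <gtest/gtest.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;

TEST(OpGridSampler2dOutTest, IdentityGridReproducesInput) {
  TensorFactory<ScalarType::Float> tf;

  // 1x1x2x2 input and an identity grid (x, y pairs) with align_corners=true.
  Tensor input = tf.make({1, 1, 2, 2}, {1.0, 2.0, 3.0, 4.0});
  Tensor grid = tf.make(
      {1, 2, 2, 2}, {-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0});
  Tensor out = tf.zeros({1, 1, 2, 2});

  torch::executor::KernelRuntimeContext context;
  torch::executor::native::grid_sampler_2d_out(
      context, input, grid, /*interpolation_mode=*/0, /*padding_mode=*/0,
      /*align_corners=*/true, out);

  // Bilinear sampling at exact pixel centers should reproduce the input.
  EXPECT_TENSOR_CLOSE(out, input);
}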
