Skip to content

Commit f9a09b4

Browse files
committed
Address copilot comments
1 parent 0ec0821 commit f9a09b4

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

kernels/portable/cpu/op_grid_sampler_2d.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,10 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
194194
y, inp_H, padding_mode, align_corners);
195195

196196
// Get nearest pixel coordinates
197-
// Use nearbyint (not round) to match ATen's rounding behavior
198-
// see: aten/src/ATen/native/GridSampler.cpp
197+
// Use nearbyint (not round) to match ATen's rounding behavior.
198+
// nearbyint uses the current rounding mode (typically round-to-even),
199+
// which matches PyTorch's (ATen's) behavior. In contrast, round may
200+
// not always respect the rounding mode. See: aten/src/ATen/native/GridSampler.cpp
199201
int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
200202
int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
201203

@@ -213,8 +215,8 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
213215
} else {
214216
// For border/reflection padding, clip coordinates after rounding
215217
// Rounding can push coordinates out of bounds even after grid_sampler_compute_source_index
216-
ix_nearest = std::max(static_cast<int64_t>(0), std::min(ix_nearest, inp_W - 1));
217-
iy_nearest = std::max(static_cast<int64_t>(0), std::min(iy_nearest, inp_H - 1));
218+
ix_nearest = clip_coordinates(ix_nearest, inp_W);
219+
iy_nearest = clip_coordinates(iy_nearest, inp_H);
218220
out_val = in_data
219221
[in_channel_offset + iy_nearest * in.strides()[2] +
220222
ix_nearest * in.strides()[3]];

kernels/portable/cpu/util/grid_sampler_2d_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ namespace executor {
1717

1818
// Ported from aten/src/ATen/native/GridSampler.h
1919
// note that these need to be in the SAME ORDER as the enum in GridSampler.h
20+
// as they are mapped to integer values (0, 1, 2) in this order
2021
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
2122
enum class GridSamplerPadding {Zeros, Border, Reflection};
2223

kernels/portable/test/op_grid_sampler_2d_test.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ TEST_F(OpGridSampler2dTest, BicubicSimple) {
200200
out);
201201

202202
// Bicubic at center should be close to 8.5 (average of middle pixels)
203+
// Note: The tolerance of 0.5 is intentionally large because the expected value (8.5)
204+
// is a rough estimate (average of the middle pixels), not the exact bicubic interpolation result.
205+
// Bicubic interpolation can produce values that differ from this average due to its mathematical properties.
203206
const auto expected = tf.make({1, 1, 1, 1}, {8.5});
204207
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 0, 0.5);
205208
}

kernels/portable/test/test_grid_sampler_2d_executorch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,8 @@ def run_executorch_test(
8787
executorch_output = fwd_method.execute((input_tensor, grid))[0]
8888

8989
# Compare results
90-
self.assertEqual(
91-
executorch_output.shape,
92-
pytorch_output.shape,
90+
self.assertTrue(
91+
executorch_output.shape == pytorch_output.shape,
9392
msg=f"Shape mismatch: ET={executorch_output.shape} vs PT={pytorch_output.shape}",
9493
)
9594

0 commit comments

Comments
 (0)