@@ -194,8 +194,10 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
194194 y, inp_H, padding_mode, align_corners);
195195
196196 // Get nearest pixel coordinates
197- // Use nearbyint (not round) to match ATen's rounding behavior
198- // see: aten/src/ATen/native/GridSampler.cpp
197+ // Use nearbyint (not round) to match ATen's rounding behavior.
198+ // nearbyint uses the current rounding mode (typically round-to-even),
199+ // which matches PyTorch's (ATen's) behavior. In contrast, round may
200+ // not always respect the rounding mode. See: aten/src/ATen/native/GridSampler.cpp
199201 int64_t ix_nearest = static_cast <int64_t >(std::nearbyint (ix));
200202 int64_t iy_nearest = static_cast <int64_t >(std::nearbyint (iy));
201203
@@ -213,8 +215,8 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
213215 } else {
214216 // For border/reflection padding, clip coordinates after rounding
215217 // Rounding can push coordinates out of bounds even after grid_sampler_compute_source_index
216- ix_nearest = std::max ( static_cast < int64_t >( 0 ), std::min ( ix_nearest, inp_W - 1 ) );
217- iy_nearest = std::max ( static_cast < int64_t >( 0 ), std::min ( iy_nearest, inp_H - 1 ) );
218+ ix_nearest = clip_coordinates ( ix_nearest, inp_W);
219+ iy_nearest = clip_coordinates ( iy_nearest, inp_H);
218220 out_val = in_data
219221 [in_channel_offset + iy_nearest * in.strides ()[2 ] +
220222 ix_nearest * in.strides ()[3 ]];
0 commit comments