@@ -92,46 +92,46 @@ void grid_sample_2d_bilinear_kernel_impl_nchw(
9292 // For zeros padding, only sample if within bounds
9393 if (within_bounds_2d (iy_nw, ix_nw, inp_H, inp_W)) {
9494 out_val += in_data
95- [in_channel_offset + iy_nw * in.strides ()[2 ] +
96- ix_nw * in.strides ()[3 ]] *
95+ [in_channel_offset + iy_nw * in.strides ()[2 ] +
96+ ix_nw * in.strides ()[3 ]] *
9797 nw_weight;
9898 }
9999 if (within_bounds_2d (iy_ne, ix_ne, inp_H, inp_W)) {
100100 out_val += in_data
101- [in_channel_offset + iy_ne * in.strides ()[2 ] +
102- ix_ne * in.strides ()[3 ]] *
101+ [in_channel_offset + iy_ne * in.strides ()[2 ] +
102+ ix_ne * in.strides ()[3 ]] *
103103 ne_weight;
104104 }
105105 if (within_bounds_2d (iy_sw, ix_sw, inp_H, inp_W)) {
106106 out_val += in_data
107- [in_channel_offset + iy_sw * in.strides ()[2 ] +
108- ix_sw * in.strides ()[3 ]] *
107+ [in_channel_offset + iy_sw * in.strides ()[2 ] +
108+ ix_sw * in.strides ()[3 ]] *
109109 sw_weight;
110110 }
111111 if (within_bounds_2d (iy_se, ix_se, inp_H, inp_W)) {
112112 out_val += in_data
113- [in_channel_offset + iy_se * in.strides ()[2 ] +
114- ix_se * in.strides ()[3 ]] *
113+ [in_channel_offset + iy_se * in.strides ()[2 ] +
114+ ix_se * in.strides ()[3 ]] *
115115 se_weight;
116116 }
117117 } else {
118118 // For border/reflection padding, coordinates are already clipped
119119 out_val = in_data
120120 [in_channel_offset + iy_nw * in.strides ()[2 ] +
121121 ix_nw * in.strides ()[3 ]] *
122- nw_weight +
123- in_data
124- [in_channel_offset + iy_ne * in.strides ()[2 ] +
125- ix_ne * in.strides ()[3 ]] *
126- ne_weight +
127- in_data
128- [in_channel_offset + iy_sw * in.strides ()[2 ] +
129- ix_sw * in.strides ()[3 ]] *
130- sw_weight +
131- in_data
132- [in_channel_offset + iy_se * in.strides ()[2 ] +
133- ix_se * in.strides ()[3 ]] *
134- se_weight;
122+ nw_weight +
123+ in_data
124+ [in_channel_offset + iy_ne * in.strides ()[2 ] +
125+ ix_ne * in.strides ()[3 ]] *
126+ ne_weight +
127+ in_data
128+ [in_channel_offset + iy_sw * in.strides ()[2 ] +
129+ ix_sw * in.strides ()[3 ]] *
130+ sw_weight +
131+ in_data
132+ [in_channel_offset + iy_se * in.strides ()[2 ] +
133+ ix_se * in.strides ()[3 ]] *
134+ se_weight;
135135 }
136136
137137 // Write output in NCHW order
@@ -197,7 +197,8 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
197197 // Use nearbyint (not round) to match ATen's rounding behavior.
198198 // nearbyint uses the current rounding mode (typically round-to-even),
199199 // which matches PyTorch's (ATen's) behavior. In contrast, round may
200- // not always respect the rounding mode. See: aten/src/ATen/native/GridSampler.cpp
200+ // not always respect the rounding mode. See:
201+ // aten/src/ATen/native/GridSampler.cpp
201202 int64_t ix_nearest = static_cast <int64_t >(std::nearbyint (ix));
202203 int64_t iy_nearest = static_cast <int64_t >(std::nearbyint (iy));
203204
@@ -214,7 +215,8 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
214215 }
215216 } else {
216217 // For border/reflection padding, clip coordinates after rounding
217- // Rounding can push coordinates out of bounds even after grid_sampler_compute_source_index
218+ // Rounding can push coordinates out of bounds even after
219+ // grid_sampler_compute_source_index
218220 ix_nearest = clip_coordinates (ix_nearest, inp_W);
219221 iy_nearest = clip_coordinates (iy_nearest, inp_H);
220222 out_val = in_data
@@ -232,7 +234,6 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
232234 }
233235}
234236
235-
236237template <typename CTYPE>
237238void grid_sample_2d_bicubic_kernel_impl_nchw (
238239 const Tensor& in,
@@ -277,8 +278,9 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
277278 const CTYPE y = grid_data[grid_idx + grid.strides ()[3 ]];
278279
279280 // Compute source coordinates in pixel space
280- // For bicubic, we need raw unnormalized coordinates without padding applied
281- // Padding is applied later when fetching individual pixels from the 4x4 neighborhood
281+ // For bicubic, we need raw unnormalized coordinates without padding
282+ // applied Padding is applied later when fetching individual pixels
283+ // from the 4x4 neighborhood
282284 CTYPE ix = grid_sampler_unnormalize (x, inp_W, align_corners);
283285 CTYPE iy = grid_sampler_unnormalize (y, inp_H, align_corners);
284286
@@ -309,12 +311,10 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
309311 return static_cast <CTYPE>(0 );
310312 } else if (padding_mode == GridSamplerPadding::Border) {
311313 // For border padding, clip coordinates to valid range
312- int64_t iy_safe = std::max (
313- static_cast <int64_t >(0 ),
314- std::min (iy, inp_H - 1 ));
315- int64_t ix_safe = std::max (
316- static_cast <int64_t >(0 ),
317- std::min (ix, inp_W - 1 ));
314+ int64_t iy_safe =
315+ std::max (static_cast <int64_t >(0 ), std::min (iy, inp_H - 1 ));
316+ int64_t ix_safe =
317+ std::max (static_cast <int64_t >(0 ), std::min (ix, inp_W - 1 ));
318318 return in_data
319319 [in_channel_offset + iy_safe * in.strides ()[2 ] +
320320 ix_safe * in.strides ()[3 ]];
@@ -324,16 +324,22 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
324324 CTYPE ix_reflected = static_cast <CTYPE>(ix);
325325
326326 if (align_corners) {
327- iy_reflected = reflect_coordinates (iy_reflected, 0 , 2 * (inp_H - 1 ));
328- ix_reflected = reflect_coordinates (ix_reflected, 0 , 2 * (inp_W - 1 ));
327+ iy_reflected =
328+ reflect_coordinates (iy_reflected, 0 , 2 * (inp_H - 1 ));
329+ ix_reflected =
330+ reflect_coordinates (ix_reflected, 0 , 2 * (inp_W - 1 ));
329331 } else {
330- iy_reflected = reflect_coordinates (iy_reflected, -1 , 2 * inp_H - 1 );
331- ix_reflected = reflect_coordinates (ix_reflected, -1 , 2 * inp_W - 1 );
332+ iy_reflected =
333+ reflect_coordinates (iy_reflected, -1 , 2 * inp_H - 1 );
334+ ix_reflected =
335+ reflect_coordinates (ix_reflected, -1 , 2 * inp_W - 1 );
332336 }
333337
334338 // Clip to ensure we're in bounds (reflection + clip for safety)
335- int64_t iy_safe = static_cast <int64_t >(clip_coordinates (iy_reflected, inp_H));
336- int64_t ix_safe = static_cast <int64_t >(clip_coordinates (ix_reflected, inp_W));
339+ int64_t iy_safe =
340+ static_cast <int64_t >(clip_coordinates (iy_reflected, inp_H));
341+ int64_t ix_safe =
342+ static_cast <int64_t >(clip_coordinates (ix_reflected, inp_W));
337343
338344 return in_data
339345 [in_channel_offset + iy_safe * in.strides ()[2 ] +
@@ -375,7 +381,11 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
375381
376382 // Interpolate in y-direction
377383 CTYPE out_val = cubic_interp1d (
378- coefficients[0 ], coefficients[1 ], coefficients[2 ], coefficients[3 ], ty);
384+ coefficients[0 ],
385+ coefficients[1 ],
386+ coefficients[2 ],
387+ coefficients[3 ],
388+ ty);
379389
380390 // Write output in NCHW order
381391 const int64_t out_idx =
@@ -409,7 +419,8 @@ Tensor& grid_sampler_2d_out(
409419 " Failed to validate arguments and resize output tensor" );
410420
411421 // Convert integer mode parameters to enums
412- GridSamplerInterpolation mode = static_cast <GridSamplerInterpolation>(interpolation_mode);
422+ GridSamplerInterpolation mode =
423+ static_cast <GridSamplerInterpolation>(interpolation_mode);
413424 GridSamplerPadding padding = static_cast <GridSamplerPadding>(padding_mode);
414425
415426 // Validate mode and padding values
@@ -454,7 +465,6 @@ Tensor& grid_sampler_2d_out(
454465// NOLINTEND(facebook-hte-ConstantArgumentPassByValue,
455466// facebook-hte-ParameterMightThrowOnCopy)
456467
457-
458468} // namespace native
459469} // namespace executor
460470} // namespace torch
0 commit comments