Skip to content

Commit c0a3db9

Browse files
committed
Fix test compilation and runtime error.
1 parent 9085caa commit c0a3db9

File tree

7 files changed

+245
-176
lines changed

7 files changed

+245
-176
lines changed

kernels/portable/cpu/op_grid_sampler_2d.cpp

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -92,46 +92,46 @@ void grid_sample_2d_bilinear_kernel_impl_nchw(
9292
// For zeros padding, only sample if within bounds
9393
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
9494
out_val += in_data
95-
[in_channel_offset + iy_nw * in.strides()[2] +
96-
ix_nw * in.strides()[3]] *
95+
[in_channel_offset + iy_nw * in.strides()[2] +
96+
ix_nw * in.strides()[3]] *
9797
nw_weight;
9898
}
9999
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
100100
out_val += in_data
101-
[in_channel_offset + iy_ne * in.strides()[2] +
102-
ix_ne * in.strides()[3]] *
101+
[in_channel_offset + iy_ne * in.strides()[2] +
102+
ix_ne * in.strides()[3]] *
103103
ne_weight;
104104
}
105105
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
106106
out_val += in_data
107-
[in_channel_offset + iy_sw * in.strides()[2] +
108-
ix_sw * in.strides()[3]] *
107+
[in_channel_offset + iy_sw * in.strides()[2] +
108+
ix_sw * in.strides()[3]] *
109109
sw_weight;
110110
}
111111
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
112112
out_val += in_data
113-
[in_channel_offset + iy_se * in.strides()[2] +
114-
ix_se * in.strides()[3]] *
113+
[in_channel_offset + iy_se * in.strides()[2] +
114+
ix_se * in.strides()[3]] *
115115
se_weight;
116116
}
117117
} else {
118118
// For border/reflection padding, coordinates are already clipped
119119
out_val = in_data
120120
[in_channel_offset + iy_nw * in.strides()[2] +
121121
ix_nw * in.strides()[3]] *
122-
nw_weight +
123-
in_data
124-
[in_channel_offset + iy_ne * in.strides()[2] +
125-
ix_ne * in.strides()[3]] *
126-
ne_weight +
127-
in_data
128-
[in_channel_offset + iy_sw * in.strides()[2] +
129-
ix_sw * in.strides()[3]] *
130-
sw_weight +
131-
in_data
132-
[in_channel_offset + iy_se * in.strides()[2] +
133-
ix_se * in.strides()[3]] *
134-
se_weight;
122+
nw_weight +
123+
in_data
124+
[in_channel_offset + iy_ne * in.strides()[2] +
125+
ix_ne * in.strides()[3]] *
126+
ne_weight +
127+
in_data
128+
[in_channel_offset + iy_sw * in.strides()[2] +
129+
ix_sw * in.strides()[3]] *
130+
sw_weight +
131+
in_data
132+
[in_channel_offset + iy_se * in.strides()[2] +
133+
ix_se * in.strides()[3]] *
134+
se_weight;
135135
}
136136

137137
// Write output in NCHW order
@@ -197,7 +197,8 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
197197
// Use nearbyint (not round) to match ATen's rounding behavior.
198198
// nearbyint uses the current rounding mode (typically round-to-even),
199199
// which matches PyTorch's (ATen's) behavior. In contrast, round may
200-
// not always respect the rounding mode. See: aten/src/ATen/native/GridSampler.cpp
200+
// not always respect the rounding mode. See:
201+
// aten/src/ATen/native/GridSampler.cpp
201202
int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
202203
int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
203204

@@ -214,7 +215,8 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
214215
}
215216
} else {
216217
// For border/reflection padding, clip coordinates after rounding
217-
// Rounding can push coordinates out of bounds even after grid_sampler_compute_source_index
218+
// Rounding can push coordinates out of bounds even after
219+
// grid_sampler_compute_source_index
218220
ix_nearest = clip_coordinates(ix_nearest, inp_W);
219221
iy_nearest = clip_coordinates(iy_nearest, inp_H);
220222
out_val = in_data
@@ -232,7 +234,6 @@ void grid_sample_2d_nearest_kernel_impl_nchw(
232234
}
233235
}
234236

235-
236237
template <typename CTYPE>
237238
void grid_sample_2d_bicubic_kernel_impl_nchw(
238239
const Tensor& in,
@@ -277,8 +278,9 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
277278
const CTYPE y = grid_data[grid_idx + grid.strides()[3]];
278279

279280
// Compute source coordinates in pixel space
280-
// For bicubic, we need raw unnormalized coordinates without padding applied
281-
// Padding is applied later when fetching individual pixels from the 4x4 neighborhood
281+
// For bicubic, we need raw unnormalized coordinates without padding
282+
// applied Padding is applied later when fetching individual pixels
283+
// from the 4x4 neighborhood
282284
CTYPE ix = grid_sampler_unnormalize(x, inp_W, align_corners);
283285
CTYPE iy = grid_sampler_unnormalize(y, inp_H, align_corners);
284286

@@ -309,12 +311,10 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
309311
return static_cast<CTYPE>(0);
310312
} else if (padding_mode == GridSamplerPadding::Border) {
311313
// For border padding, clip coordinates to valid range
312-
int64_t iy_safe = std::max(
313-
static_cast<int64_t>(0),
314-
std::min(iy, inp_H - 1));
315-
int64_t ix_safe = std::max(
316-
static_cast<int64_t>(0),
317-
std::min(ix, inp_W - 1));
314+
int64_t iy_safe =
315+
std::max(static_cast<int64_t>(0), std::min(iy, inp_H - 1));
316+
int64_t ix_safe =
317+
std::max(static_cast<int64_t>(0), std::min(ix, inp_W - 1));
318318
return in_data
319319
[in_channel_offset + iy_safe * in.strides()[2] +
320320
ix_safe * in.strides()[3]];
@@ -324,16 +324,22 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
324324
CTYPE ix_reflected = static_cast<CTYPE>(ix);
325325

326326
if (align_corners) {
327-
iy_reflected = reflect_coordinates(iy_reflected, 0, 2 * (inp_H - 1));
328-
ix_reflected = reflect_coordinates(ix_reflected, 0, 2 * (inp_W - 1));
327+
iy_reflected =
328+
reflect_coordinates(iy_reflected, 0, 2 * (inp_H - 1));
329+
ix_reflected =
330+
reflect_coordinates(ix_reflected, 0, 2 * (inp_W - 1));
329331
} else {
330-
iy_reflected = reflect_coordinates(iy_reflected, -1, 2 * inp_H - 1);
331-
ix_reflected = reflect_coordinates(ix_reflected, -1, 2 * inp_W - 1);
332+
iy_reflected =
333+
reflect_coordinates(iy_reflected, -1, 2 * inp_H - 1);
334+
ix_reflected =
335+
reflect_coordinates(ix_reflected, -1, 2 * inp_W - 1);
332336
}
333337

334338
// Clip to ensure we're in bounds (reflection + clip for safety)
335-
int64_t iy_safe = static_cast<int64_t>(clip_coordinates(iy_reflected, inp_H));
336-
int64_t ix_safe = static_cast<int64_t>(clip_coordinates(ix_reflected, inp_W));
339+
int64_t iy_safe =
340+
static_cast<int64_t>(clip_coordinates(iy_reflected, inp_H));
341+
int64_t ix_safe =
342+
static_cast<int64_t>(clip_coordinates(ix_reflected, inp_W));
337343

338344
return in_data
339345
[in_channel_offset + iy_safe * in.strides()[2] +
@@ -375,7 +381,11 @@ void grid_sample_2d_bicubic_kernel_impl_nchw(
375381

376382
// Interpolate in y-direction
377383
CTYPE out_val = cubic_interp1d(
378-
coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty);
384+
coefficients[0],
385+
coefficients[1],
386+
coefficients[2],
387+
coefficients[3],
388+
ty);
379389

380390
// Write output in NCHW order
381391
const int64_t out_idx =
@@ -409,7 +419,8 @@ Tensor& grid_sampler_2d_out(
409419
"Failed to validate arguments and resize output tensor");
410420

411421
// Convert integer mode parameters to enums
412-
GridSamplerInterpolation mode = static_cast<GridSamplerInterpolation>(interpolation_mode);
422+
GridSamplerInterpolation mode =
423+
static_cast<GridSamplerInterpolation>(interpolation_mode);
413424
GridSamplerPadding padding = static_cast<GridSamplerPadding>(padding_mode);
414425

415426
// Validate mode and padding values
@@ -454,7 +465,6 @@ Tensor& grid_sampler_2d_out(
454465
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue,
455466
// facebook-hte-ParameterMightThrowOnCopy)
456467

457-
458468
} // namespace native
459469
} // namespace executor
460470
} // namespace torch

kernels/portable/cpu/util/grid_sampler_2d_util.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ Error check_grid_sampler_2d_args_and_resize_out(
5555
InvalidArgument,
5656
"Input and grid must have same dtype");
5757

58+
// Input and output must have the same dtype
59+
ET_CHECK_OR_RETURN_ERROR(
60+
tensors_have_same_dtype(input, out),
61+
InvalidArgument,
62+
"Input and output must have the same dtype");
63+
5864
// Resize output tensor to [N, C, H_out, W_out]
5965
std::array<exec_aten::SizesType, 4> out_sizes = {
6066
static_cast<exec_aten::SizesType>(input.size(0)),
@@ -64,9 +70,7 @@ Error check_grid_sampler_2d_args_and_resize_out(
6470

6571
Error err = resize_tensor(out, {out_sizes.data(), 4});
6672
ET_CHECK_OR_RETURN_ERROR(
67-
err == Error::Ok,
68-
InvalidArgument,
69-
"Failed to resize output tensor");
73+
err == Error::Ok, InvalidArgument, "Failed to resize output tensor");
7074

7175
return Error::Ok;
7276
}

kernels/portable/cpu/util/grid_sampler_2d_util.h

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ namespace executor {
1818
// Ported from aten/src/ATen/native/GridSampler.h
1919
// note that these need to be in the SAME ORDER as the enum in GridSampler.h
2020
// as they are mapped to integer values (0, 1, 2) in this order
21-
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
22-
enum class GridSamplerPadding {Zeros, Border, Reflection};
23-
21+
enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic };
22+
enum class GridSamplerPadding { Zeros, Border, Reflection };
2423

2524
// Ported from aten/src/ATen/native/GridSampler.h
2625
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
@@ -34,10 +33,8 @@ enum class GridSamplerPadding {Zeros, Border, Reflection};
3433
// +1 --> (size - 1) + 0.5 == size - 0.5
3534
// scale_factor = size / 2
3635
template <typename scalar_t>
37-
inline scalar_t grid_sampler_unnormalize(
38-
scalar_t coord,
39-
int64_t size,
40-
bool align_corners) {
36+
inline scalar_t
37+
grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners) {
4138
if (align_corners) {
4239
// unnormalize coord from [-1, 1] to [0, size - 1]
4340
return ((coord + 1) / 2) * (size - 1);
@@ -61,10 +58,8 @@ inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
6158
// The bounds are passed as twice their value so that half-integer values
6259
// can be represented as ints.
6360
template <typename scalar_t>
64-
inline scalar_t reflect_coordinates(
65-
scalar_t in,
66-
int64_t twice_low,
67-
int64_t twice_high) {
61+
inline scalar_t
62+
reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high) {
6863
if (twice_low == twice_high) {
6964
return static_cast<scalar_t>(0);
7065
}
@@ -120,14 +115,16 @@ inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
120115
}
121116

122117
// Ported from aten/src/ATen/native/UpSample.h
123-
// Cubic convolution function 2 (for points between 1 and 2 units from the point)
118+
// Cubic convolution function 2 (for points between 1 and 2 units from the
119+
// point)
124120
template <typename scalar_t>
125121
inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
126122
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
127123
}
128124

129125
// Ported from aten/src/ATen/native/UpSample.h
130-
// Computes the 4 cubic interpolation coefficients for a given position t in [0, 1]
126+
// Computes the 4 cubic interpolation coefficients for a given position t in [0,
127+
// 1]
131128
template <typename scalar_t>
132129
inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) {
133130
// Standard bicubic interpolation uses alpha = -0.75
@@ -145,12 +142,8 @@ inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) {
145142
// Ported from aten/src/ATen/native/UpSample.h
146143
// Performs 1D cubic interpolation given 4 points and a position t in [0, 1]
147144
template <typename scalar_t>
148-
inline scalar_t cubic_interp1d(
149-
scalar_t x0,
150-
scalar_t x1,
151-
scalar_t x2,
152-
scalar_t x3,
153-
scalar_t t) {
145+
inline scalar_t
146+
cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) {
154147
scalar_t coeffs[4];
155148
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
156149

kernels/portable/test/op_grid_sampler_2d_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,4 @@ def test_grid_sampler_2d_batch_processing(self):
232232

233233

234234
if __name__ == "__main__":
235-
unittest.main()
235+
unittest.main()

kernels/portable/test/targets.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def define_common_targets():
6868
op_test(name = "op_allclose_test")
6969
op_test(name = "op_div_test")
7070
op_test(name = "op_gelu_test")
71-
op_test(name = "op_grid_sampler_2d_test")
7271
op_test(name = "op_mul_test")
7372

7473
if is_xplat():

0 commit comments

Comments
 (0)