Skip to content

Commit 7f7264b

Browse files
committed
Fix out of bounds error. Attempt to fix buck2 issues
1 parent 8155889 commit 7f7264b

File tree

5 files changed

+37
-12
lines changed

5 files changed

+37
-12
lines changed

kernels/portable/cpu/op_grid_sampler_2d.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,31 @@ void grid_sample_2d_bilinear_kernel_impl_nchw(
115115
se_weight;
116116
}
117117
} else {
118-
// For border/reflection padding, coordinates are already clipped
118+
// For border/reflection padding, clip corner indices to valid range
119+
// Even though source coordinates are clipped, adding 1 can push corners out of bounds
120+
const int64_t ix_nw_safe = clip_coordinates(ix_nw, inp_W);
121+
const int64_t iy_nw_safe = clip_coordinates(iy_nw, inp_H);
122+
const int64_t ix_ne_safe = clip_coordinates(ix_ne, inp_W);
123+
const int64_t iy_ne_safe = clip_coordinates(iy_ne, inp_H);
124+
const int64_t ix_sw_safe = clip_coordinates(ix_sw, inp_W);
125+
const int64_t iy_sw_safe = clip_coordinates(iy_sw, inp_H);
126+
const int64_t ix_se_safe = clip_coordinates(ix_se, inp_W);
127+
const int64_t iy_se_safe = clip_coordinates(iy_se, inp_H);
119128
out_val = in_data
120-
[in_channel_offset + iy_nw * in.strides()[2] +
121-
ix_nw * in.strides()[3]] *
129+
[in_channel_offset + iy_nw_safe * in.strides()[2] +
130+
ix_nw_safe * in.strides()[3]] *
122131
nw_weight +
123132
in_data
124-
[in_channel_offset + iy_ne * in.strides()[2] +
125-
ix_ne * in.strides()[3]] *
133+
[in_channel_offset + iy_ne_safe * in.strides()[2] +
134+
ix_ne_safe * in.strides()[3]] *
126135
ne_weight +
127136
in_data
128-
[in_channel_offset + iy_sw * in.strides()[2] +
129-
ix_sw * in.strides()[3]] *
137+
[in_channel_offset + iy_sw_safe * in.strides()[2] +
138+
ix_sw_safe * in.strides()[3]] *
130139
sw_weight +
131140
in_data
132-
[in_channel_offset + iy_se * in.strides()[2] +
133-
ix_se * in.strides()[3]] *
141+
[in_channel_offset + iy_se_safe * in.strides()[2] +
142+
ix_se_safe * in.strides()[3]] *
134143
se_weight;
135144
}
136145

kernels/portable/cpu/util/targets.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def define_common_targets():
3636
"//executorch/kernels/portable/cpu/util:elementwise_util",
3737
"//executorch/kernels/portable/cpu/util:upsample_util",
3838
"//executorch/kernels/portable/cpu/util:vectorized_math",
39+
"//executorch/kernels/portable/cpu/util:grid_sampler_2d_util",
3940
],
4041
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
4142
)
@@ -342,6 +343,16 @@ def define_common_targets():
342343
],
343344
)
344345

346+
runtime.cxx_library(
347+
name = "grid_sampler_2d_util",
348+
srcs = ["grid_sampler_2d_util.cpp"],
349+
exported_headers = ["grid_sampler_2d_util.h"],
350+
deps = [
351+
"//executorch/runtime/kernel:kernel_includes",
352+
],
353+
visibility = ["//executorch/kernels/portable/cpu/..."],
354+
)
355+
345356
# Utility functions that can be used by operators that perform reduction
346357
for aten_mode in get_aten_mode_options():
347358
suffix = "_aten" if aten_mode else ""

kernels/portable/test/op_grid_sampler_2d_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# pyre-unsafe
88

99
import itertools
10-
import os
1110
import unittest
1211

1312
import torch

kernels/portable/test/test_grid_sampler_2d_executorch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import itertools
1313
import sys
1414
import unittest
15-
from typing import Tuple
1615

1716
import torch
1817
import torch.nn as nn
@@ -87,7 +86,7 @@ def run_executorch_test(
8786
"forward"
8887
)
8988
if fwd_method is None:
90-
self.fail(f"Failed to load forward method")
89+
self.fail("Failed to load forward method")
9190
executorch_output = fwd_method.execute((input_tensor, grid))[0]
9291

9392
# Compare results

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,13 @@ ATEN_OPS = (
628628
"//executorch/runtime/core/exec_aten/util:tensor_util",
629629
],
630630
),
631+
op_target(
632+
name = "op_grid_sampler_2d",
633+
deps = [
634+
"//executorch/kernels/portable/cpu/util:grid_sampler_2d_util",
635+
"//executorch/runtime/core/exec_aten/util:tensor_util",
636+
],
637+
),
631638
op_target(
632639
name = "op_gt",
633640
deps = [

0 commit comments

Comments
 (0)