Skip to content

Commit cd8f804

Browse files
committed
Add aten::_upsample_bilinear2d_aa.out
Summary: Trying to resolve #7031
1 parent f2a7f46 commit cd8f804

File tree

9 files changed

+976
-0
lines changed

9 files changed

+976
-0
lines changed
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <c10/util/irange.h>
9+
#include <algorithm>
10+
#include <cmath>
11+
12+
#include <executorch/kernels/portable/cpu/util/upsample_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using executorch::aten::ArrayRef;
20+
using executorch::aten::SizesType;
21+
using std::optional;
22+
23+
namespace {
24+
25+
// Triangle (tent) filter used for anti-aliased bilinear resampling:
// f(x) = max(0, 1 - |x|). Adapted from PyTorch's implementation.
template <typename T>
inline T bilinear_aa_filter(T x) {
  const T distance = std::abs(x);
  return (distance < 1.0) ? (1.0 - distance) : 0.0;
}
35+
36+
// Compute weights and indices for a single output pixel with anti-aliasing
37+
template <typename T>
38+
void compute_aa_weights_for_pixel(
39+
int64_t output_idx,
40+
int64_t input_size,
41+
int64_t output_size,
42+
bool align_corners,
43+
int64_t* indices,
44+
T* weights,
45+
int64_t* num_contributors) {
46+
const T scale = area_pixel_compute_scale<T>(
47+
input_size, output_size, align_corners, optional<double>());
48+
49+
const T center = area_pixel_compute_source_index<T>(
50+
scale, output_idx, align_corners, /*cubic=*/false);
51+
52+
// Support is the filter radius - for downsampling, we need a larger filter
53+
const T support = (scale >= 1.0) ? scale : 1.0;
54+
55+
// Find the range of input pixels that contribute
56+
const int64_t xmin = std::max(
57+
static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
58+
const int64_t xmax =
59+
std::min(static_cast<int64_t>(center + support + 0.5), input_size);
60+
61+
T total_weight = 0.0;
62+
*num_contributors = std::min(xmax - xmin, int64_t(4));
63+
64+
// Compute weights for contributing pixels
65+
for (int64_t j = 0; j < *num_contributors; ++j) {
66+
int64_t x = xmin + j;
67+
T weight = bilinear_aa_filter<T>(
68+
(x - center + 0.5) / (scale >= 1.0 ? scale : 1.0));
69+
indices[j] = x;
70+
weights[j] = weight;
71+
total_weight += weight;
72+
}
73+
74+
// Normalize weights
75+
if (total_weight > 0) {
76+
for (int64_t j = 0; j < *num_contributors; ++j) {
77+
weights[j] /= total_weight;
78+
}
79+
}
80+
}
81+
82+
// Typed kernel for anti-aliased bilinear 2D resampling over a rank-4 tensor.
// Dispatches on memory layout: contiguous dim order is treated as NCHW,
// otherwise the NHWC path is taken. For each output pixel it computes
// separable per-row and per-column AA weights (up to 4 taps each, matching
// the cap inside compute_aa_weights_for_pixel) and accumulates the weighted
// sum of contributing input pixels.
//
// NOTE(review): `ctx`, `scale_h`, and `scale_w` are currently unused —
// compute_aa_weights_for_pixel recomputes its own scale from the sizes and
// ignores the scales the caller derived; confirm whether the caller-provided
// scales should be threaded through instead.
template <typename CTYPE>
void upsample_bilinear2d_aa_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  // Contiguous dim order implies NCHW; anything else is handled as NHWC.
  const bool is_nchw =
      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());

  if (is_nchw) {
    // NCHW layout: iterate planes (n, c), then output rows/cols.
    for (int64_t n = 0; n < out.size(0); ++n) {
      for (int64_t c = 0; c < out.size(1); ++c) {
        // Pointers to the start of this (n, c) spatial plane.
        const auto in_plane =
            in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
        auto out_plane =
            out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);

        for (int64_t oh = 0; oh < out.size(2); ++oh) {
          // Compute height weights for this output row.
          // Buffers sized 4 to match the helper's contributor cap.
          int64_t h_indices[4];
          float h_weights[4];
          int64_t h_num_contributors;
          compute_aa_weights_for_pixel<float>(
              oh,
              in.size(2),
              out.size(2),
              align_corners,
              h_indices,
              h_weights,
              &h_num_contributors);

          for (int64_t ow = 0; ow < out.size(3); ++ow) {
            // Compute width weights for this output column
            int64_t w_indices[4];
            float w_weights[4];
            int64_t w_num_contributors;
            compute_aa_weights_for_pixel<float>(
                ow,
                in.size(3),
                out.size(3),
                align_corners,
                w_indices,
                w_weights,
                &w_num_contributors);

            CTYPE value = 0;

            // Apply anti-aliased interpolation: separable weighted sum over
            // the contributing rows and columns.
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
              }
            }

            out_plane[oh * out.size(3) + ow] = value;
          }
        }
      }
    }
  } else {
    // NHWC layout: channels are innermost, so weights are computed once per
    // output (oh, ow) and reused across all channels.
    for (int64_t n = 0; n < out.size(0); ++n) {
      const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
      auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);

      for (int64_t oh = 0; oh < out.size(2); ++oh) {
        // Compute height weights for this output row
        int64_t h_indices[4];
        float h_weights[4];
        int64_t h_num_contributors;
        compute_aa_weights_for_pixel<float>(
            oh,
            in.size(2),
            out.size(2),
            align_corners,
            h_indices,
            h_weights,
            &h_num_contributors);

        for (int64_t ow = 0; ow < out.size(3); ++ow) {
          // Compute width weights for this output column
          int64_t w_indices[4];
          float w_weights[4];
          int64_t w_num_contributors;
          compute_aa_weights_for_pixel<float>(
              ow,
              in.size(3),
              out.size(3),
              align_corners,
              w_indices,
              w_weights,
              &w_num_contributors);

          for (int64_t c = 0; c < out.size(1); ++c) {
            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                // NHWC indexing: ((h * W) + w) * C + c.
                value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
                    h_weight * w_weight;
              }
            }

            out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
          }
        }
      }
    }
  }
}
211+
212+
} // namespace
213+
214+
// Check function for anti-aliased bilinear upsampling
215+
bool check_upsample_bilinear2d_aa_args(
216+
const Tensor& in,
217+
const executorch::aten::OptionalArrayRef<int64_t>& output_size,
218+
const bool align_corners,
219+
const executorch::aten::OptionalArrayRef<double>& scale_factors,
220+
Tensor& out) {
221+
// Use the same checks as regular bilinear upsampling
222+
return check_upsample_bilinear2d_args(
223+
in, output_size, align_corners, scale_factors, out);
224+
}
225+
226+
// Main entry point for anti-aliased bilinear upsampling
227+
Tensor& _upsample_bilinear2d_aa_out(
228+
KernelRuntimeContext& ctx,
229+
const Tensor& in,
230+
const executorch::aten::OptionalArrayRef<int64_t> output_size,
231+
bool align_corners,
232+
const executorch::aten::OptionalArrayRef<double> scale_factors,
233+
Tensor& out) {
234+
// Preconditions (checked in check_..._args):
235+
// In and out tensors have same dtype.
236+
// In and out tensors are rank 4 and have same dim[0] and dim[1].
237+
// In and out tensors are NHWC or NCHW dim order.
238+
ET_KERNEL_CHECK(
239+
ctx,
240+
check_upsample_bilinear2d_aa_args(
241+
in, output_size, align_corners, scale_factors, out),
242+
InvalidArgument,
243+
out);
244+
245+
double scale_h, scale_w;
246+
247+
ET_KERNEL_CHECK_MSG(
248+
ctx,
249+
resize_upsample_2d(
250+
in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
251+
InvalidArgument,
252+
out,
253+
"Failed to resize output tensor");
254+
255+
const auto kernel_scale_h = area_pixel_compute_scale<double>(
256+
in.sizes()[2], out.sizes()[2], align_corners, scale_h);
257+
const auto kernel_scale_w = area_pixel_compute_scale<double>(
258+
in.sizes()[3], out.sizes()[3], align_corners, scale_w);
259+
260+
ET_SWITCH_REALHBF16_TYPES(
261+
in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
262+
upsample_bilinear2d_aa_kernel_impl<CTYPE>(
263+
ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
264+
});
265+
266+
return out;
267+
}
268+
269+
} // namespace native
270+
} // namespace executor
271+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,11 @@
965965
- arg_meta: null
966966
kernel_name: torch::executor::upsample_bilinear2d_vec_out
967967

968+
- op: _upsample_bilinear2d_aa.out
969+
kernels:
970+
- arg_meta: null
971+
kernel_name: torch::executor::_upsample_bilinear2d_aa_out
972+
968973
- op: upsample_nearest2d.vec_out
969974
kernels:
970975
- arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ runtime.cxx_library(
2020
deps = [
2121
"//executorch/extension/aten_util:aten_bridge",
2222
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
23+
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
2324
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",
2425
"//executorch/runtime/core/exec_aten:lib",
2526
],

0 commit comments

Comments
 (0)