Commit f1c725e

Add aten::_upsample_bilinear2d_aa.out
Summary: Trying to resolve #7031
1 parent 1b4968f

10 files changed: +1063 -0 lines changed

kernels/portable/cpu/op_upsample_bilinear2d_aa.cpp

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using executorch::aten::ArrayRef;
using executorch::aten::SizesType;
using std::optional;

namespace {

// Anti-aliasing filter for bilinear interpolation
// Adapted from PyTorch's implementation
template <typename T>
inline T bilinear_aa_filter(T x) {
  x = std::abs(x);
  if (x < 1.0) {
    return 1.0 - x;
  }
  return 0.0;
}
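
// For illustration: this is a unit triangle (tent) filter, so
// bilinear_aa_filter(0.0) == 1.0, bilinear_aa_filter(0.5) == 0.5, and any
// |x| >= 1.0 contributes nothing.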

// Compute weights and indices for a single output pixel with anti-aliasing.
// `scale` is the precomputed area-pixel scale for this dimension, forwarded
// from the caller (see area_pixel_compute_scale in the entry point).
template <typename T>
void compute_aa_weights_for_pixel(
    int64_t output_idx,
    int64_t input_size,
    bool align_corners,
    T scale,
    int64_t* indices,
    T* weights,
    int64_t* num_contributors) {
  const T center = area_pixel_compute_source_index<T>(
      scale, output_idx, align_corners, /*cubic=*/false);

  // Anti-aliasing is only applied for downsampling (scale > 1.0).
  // For upsampling or identity scaling, use regular bilinear interpolation.
  if (scale <= 1.0) {
    // Regular bilinear interpolation for upsampling
    int64_t x1 = static_cast<int64_t>(std::floor(center));
    int64_t x2 = std::min(x1 + 1, input_size - 1);

    T lambda = center - x1;

    if (x1 == x2) {
      // Single pixel case
      indices[0] = x1;
      weights[0] = 1.0;
      *num_contributors = 1;
    } else {
      // Two pixel interpolation
      indices[0] = x1;
      weights[0] = 1.0 - lambda;
      indices[1] = x2;
      weights[1] = lambda;
      *num_contributors = 2;
    }
    return;
  }

  // Anti-aliasing for downsampling (scale > 1.0)
  const T support = scale;

  // Find the range of input pixels that contribute
  const int64_t xmin = std::max(
      static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
  const int64_t xmax =
      std::min(static_cast<int64_t>(center + support + 0.5), input_size);

  T total_weight = 0.0;
  // The callers provide fixed-size buffers of 4 entries, so the contributor
  // count is clamped to 4; for scales above ~2x this truncates the filter
  // support.
  *num_contributors = std::min(xmax - xmin, int64_t(4));

  // Compute anti-aliasing weights for contributing pixels
  for (int64_t j = 0; j < *num_contributors; ++j) {
    int64_t x = xmin + j;
    T weight = bilinear_aa_filter<T>((x - center + 0.5) / scale);
    indices[j] = x;
    weights[j] = weight;
    total_weight += weight;
  }

  // Normalize weights so they sum to 1
  if (total_weight > 0) {
    for (int64_t j = 0; j < *num_contributors; ++j) {
      weights[j] /= total_weight;
    }
  }
}
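
// Worked example (illustrative): downscaling a row of 4 pixels to 2 with
// align_corners = false gives scale = 2.0, so for output_idx = 0 the source
// center is 2.0 * (0 + 0.5) - 0.5 = 0.5 and support = 2.0. Inputs 0..2
// contribute raw triangle weights 1.0, 0.5, and 0.0, which normalize to
// 2/3, 1/3, and 0.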

template <typename CTYPE>
void upsample_bilinear2d_aa_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  const bool is_nchw =
      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());

  if (is_nchw) {
    // NCHW layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      for (int64_t c = 0; c < out.size(1); ++c) {
        const auto in_plane =
            in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
        auto out_plane =
            out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);

        for (int64_t oh = 0; oh < out.size(2); ++oh) {
          // Compute height weights for this output row
          int64_t h_indices[4];
          float h_weights[4];
          int64_t h_num_contributors;
          compute_aa_weights_for_pixel<float>(
              oh,
              in.size(2),
              align_corners,
              scale_h,
              h_indices,
              h_weights,
              &h_num_contributors);

          for (int64_t ow = 0; ow < out.size(3); ++ow) {
            // Compute width weights for this output column
            int64_t w_indices[4];
            float w_weights[4];
            int64_t w_num_contributors;
            compute_aa_weights_for_pixel<float>(
                ow,
                in.size(3),
                align_corners,
                scale_w,
                w_indices,
                w_weights,
                &w_num_contributors);

            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
              }
            }

            out_plane[oh * out.size(3) + ow] = value;
          }
        }
      }
    }
  } else {
    // NHWC layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
      auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);

      for (int64_t oh = 0; oh < out.size(2); ++oh) {
        // Compute height weights for this output row
        int64_t h_indices[4];
        float h_weights[4];
        int64_t h_num_contributors;
        compute_aa_weights_for_pixel<float>(
            oh,
            in.size(2),
            align_corners,
            scale_h,
            h_indices,
            h_weights,
            &h_num_contributors);

        for (int64_t ow = 0; ow < out.size(3); ++ow) {
          // Compute width weights for this output column
          int64_t w_indices[4];
          float w_weights[4];
          int64_t w_num_contributors;
          compute_aa_weights_for_pixel<float>(
              ow,
              in.size(3),
              align_corners,
              scale_w,
              w_indices,
              w_weights,
              &w_num_contributors);

          for (int64_t c = 0; c < out.size(1); ++c) {
            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
                    h_weight * w_weight;
              }
            }

            out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
          }
        }
      }
    }
  }
}

} // namespace

// Check function for anti-aliased bilinear upsampling
bool check_upsample_bilinear2d_aa_args(
    const Tensor& in,
    const executorch::aten::OptionalArrayRef<int64_t>& output_size,
    const bool align_corners,
    const executorch::aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  // Use the same checks as regular bilinear upsampling
  return check_upsample_bilinear2d_args(
      in, output_size, align_corners, scale_factors, out);
}

// Main entry point for anti-aliased bilinear upsampling
Tensor& _upsample_bilinear2d_aa_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const executorch::aten::ArrayRef<int64_t> output_size,
    bool align_corners,
    const std::optional<double> scale_h,
    const std::optional<double> scale_w,
    Tensor& out) {
  // Preconditions (validated below):
  // - In and out tensors have the same dtype.
  // - In and out tensors are rank 4 and have the same dim[0] and dim[1].
  // - In and out tensors are NHWC or NCHW dim order.

  // Custom validation for this interface (ArrayRef output size + optional
  // per-dimension scales)
  ET_KERNEL_CHECK(ctx, in.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, out.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, output_size.size() == 2, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, output_size[0] > 0 && output_size[1] > 0, InvalidArgument, out);

  // Ensure the output tensor has the expected dimensions
  ET_KERNEL_CHECK(
      ctx, out.size(0) == in.size(0), InvalidArgument, out); // batch
  ET_KERNEL_CHECK(
      ctx, out.size(1) == in.size(1), InvalidArgument, out); // channels
  ET_KERNEL_CHECK(
      ctx, out.size(2) == output_size[0], InvalidArgument, out); // height
  ET_KERNEL_CHECK(
      ctx, out.size(3) == output_size[1], InvalidArgument, out); // width

  // Compute final scales: use the provided scales if available, otherwise
  // derive them from the input/output sizes.
  double final_scale_h, final_scale_w;
  if (scale_h.has_value() && scale_w.has_value()) {
    final_scale_h = scale_h.value();
    final_scale_w = scale_w.value();
  } else {
    final_scale_h =
        static_cast<double>(output_size[0]) / static_cast<double>(in.size(2));
    final_scale_w =
        static_cast<double>(output_size[1]) / static_cast<double>(in.size(3));
  }

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, final_scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, final_scale_w);

  ET_SWITCH_REALHBF16_TYPES(
      in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
        upsample_bilinear2d_aa_kernel_impl<CTYPE>(
            ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
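
For context, here is a minimal sketch of exercising the new kernel directly from C++. It is illustrative only: it assumes ExecuTorch's TensorFactory test helper and a default-constructed KernelRuntimeContext, as the portable kernel tests do; it assumes the kernel's declaration is visible (e.g. via the generated NativeFunctions.h); and the ramp values are made up.

#include <optional>

#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using executorch::aten::ScalarType;
using torch::executor::testing::TensorFactory;

int main() {
  TensorFactory<ScalarType::Float> tf;

  // 1x1x4x4 ramp input, downsampled to 1x1x2x2 with anti-aliasing.
  auto in = tf.make(
      {1, 1, 4, 4},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
  auto out = tf.zeros({1, 1, 2, 2});

  const int64_t sizes[2] = {2, 2};
  executorch::aten::ArrayRef<int64_t> output_size(sizes, 2);

  // No explicit scale factors: the kernel derives them from the sizes.
  torch::executor::KernelRuntimeContext ctx;
  torch::executor::native::_upsample_bilinear2d_aa_out(
      ctx,
      in,
      output_size,
      /*align_corners=*/false,
      /*scale_h=*/std::nullopt,
      /*scale_w=*/std::nullopt,
      out);
  return 0;
}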

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -965,6 +965,11 @@
   - arg_meta: null
     kernel_name: torch::executor::upsample_bilinear2d_vec_out

+- op: _upsample_bilinear2d_aa.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_upsample_bilinear2d_aa_out
+
 - op: upsample_nearest2d.vec_out
   kernels:
     - arg_meta: null
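
This entry registers the portable kernel under the ATen overload name, so the ExecuTorch runtime can dispatch _upsample_bilinear2d_aa.out calls to the torch::executor::_upsample_bilinear2d_aa_out function added above.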

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ runtime.cxx_library(
     deps = [
         "//executorch/extension/aten_util:aten_bridge",
         "//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
+        "//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
         "//executorch/kernels/portable/cpu:op_upsample_nearest2d",
         "//executorch/runtime/core/exec_aten:lib",
     ],
