Skip to content

Commit ed61348

Browse files
authored
Add aten::_upsample_bilinear2d_aa.out (#13458)
Summary: Trying to resolve #7031. Vibe-coded using the existing non-anti-aliased version in ExecuTorch (ET) and the ATen implementation in PyTorch as references, along with the reference unit tests in PyTorch core. Test Plan: 1. Run https://gist.github.com/mergennachin/9b02aee4feb5acc83e71d8f902f5cca1 and then call `./cmake-out/executor_runner minicpmv_preprocessor.pte`. 2. Run https://gist.github.com/mergennachin/a24e4509804de99caf906c9b79ea45fc
1 parent 2f4b704 commit ed61348

File tree

11 files changed

+1275
-0
lines changed

11 files changed

+1275
-0
lines changed
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <c10/util/irange.h>
9+
#include <algorithm>
10+
#include <cmath>
11+
12+
#include <executorch/kernels/portable/cpu/util/upsample_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using executorch::aten::ArrayRef;
20+
using executorch::aten::SizesType;
21+
using std::optional;
22+
23+
namespace {
24+
25+
// Triangle (tent) filter used for anti-aliased bilinear interpolation:
// returns 1 - |x| for |x| < 1 and 0 everywhere else.
template <typename T>
inline T bilinear_aa_filter(T x) {
  x = std::abs(x);
  return (x < static_cast<T>(1.0)) ? (static_cast<T>(1.0) - x)
                                   : static_cast<T>(0.0);
}

// Computes the input taps and normalized filter weights that contribute to
// one output coordinate along a single axis.
//
// Params:
//   output_idx       - index of the output pixel along this axis.
//   scale            - input/output ratio along this axis as supplied by the
//                      caller; it is used as-is and never recomputed here.
//   input_size       - number of input pixels along this axis.
//   indices          - out: input indices of the contributing taps. Must have
//                      room for 4 entries; only the first *num_contributors
//                      entries are written.
//   weights          - out: weight per tap. Must have room for 4 entries;
//                      entries at and after *num_contributors are set to 0.
//   num_contributors - out: number of valid taps, always in [0, 4].
//
// The centre is always scale * (output_idx + 0.5); this function has no
// align_corners input.
//
// NOTE(review): the tap window spans roughly 2 * support + 1 input pixels and
// support grows with `scale`, so for larger downscale factors the window can
// exceed the 4-slot cap. Taps beyond the first 4 are dropped before
// normalization - confirm this truncation is acceptable for the intended use.
template <typename T>
void compute_aa_weights_for_pixel(
    int64_t output_idx,
    T scale,
    int64_t input_size,
    int64_t* indices,
    T* weights,
    int64_t* num_contributors) {
  // Number of slots the output buffers are expected to provide.
  constexpr int64_t kMaxContributors = 4;

  const bool downscaling = scale >= static_cast<T>(1.0);

  // Centre of the output pixel expressed in input coordinates.
  const T center = scale * (output_idx + static_cast<T>(0.5));

  // interp_size = 2 for bilinear, so the base support is 1.0; it is widened
  // by `scale` when downscaling so the filter covers every source pixel.
  const T support = downscaling ? (static_cast<T>(1.0) * scale)
                                : static_cast<T>(1.0);
  const T invscale = downscaling ? (static_cast<T>(1.0) / scale)
                                 : static_cast<T>(1.0);

  // Half-open tap range [xmin, xmax), clipped to the input extent.
  const int64_t xmin = std::max(
      static_cast<int64_t>(center - support + static_cast<T>(0.5)),
      static_cast<int64_t>(0));
  const int64_t xmax = std::min(
      static_cast<int64_t>(center + support + static_cast<T>(0.5)), input_size);

  // If the window lies entirely past the end of the input, xmax - xmin is
  // negative. Clamp to zero: a negative count would make the clearing loop
  // below start at a negative index and write before weights[0].
  const int64_t count = std::max(
      static_cast<int64_t>(0), std::min(xmax - xmin, kMaxContributors));
  *num_contributors = count;

  T total_weight = static_cast<T>(0.0);
  for (int64_t j = 0; j < count; ++j) {
    // Filter argument: distance from the tap centre to the output centre,
    // rescaled into filter units.
    const T arg = (static_cast<T>(j) + static_cast<T>(xmin) - center +
                   static_cast<T>(0.5)) *
        invscale;
    const T weight = bilinear_aa_filter<T>(arg);
    indices[j] = xmin + j;
    weights[j] = weight;
    total_weight += weight;
  }

  // Normalize so the weights sum to 1; skipped when every tap weighs zero to
  // avoid dividing by zero.
  if (total_weight > static_cast<T>(0.0)) {
    for (int64_t j = 0; j < count; ++j) {
      weights[j] /= total_weight;
    }
  }

  // Zero the unused weight slots so callers never read stale values.
  for (int64_t j = count; j < kMaxContributors; ++j) {
    weights[j] = static_cast<T>(0.0);
  }
}
94+
95+
template <typename CTYPE>
96+
void upsample_bilinear2d_aa_kernel_impl(
97+
KernelRuntimeContext& ctx,
98+
const Tensor& in,
99+
bool align_corners,
100+
const float scale_h,
101+
const float scale_w,
102+
Tensor& out) {
103+
const auto in_data = in.const_data_ptr<CTYPE>();
104+
auto out_data = out.mutable_data_ptr<CTYPE>();
105+
106+
const bool is_nchw =
107+
is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());
108+
109+
if (is_nchw) {
110+
// NCHW layout
111+
for (int64_t n = 0; n < out.size(0); ++n) {
112+
for (int64_t c = 0; c < out.size(1); ++c) {
113+
const auto in_plane =
114+
in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
115+
auto out_plane =
116+
out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);
117+
118+
for (int64_t oh = 0; oh < out.size(2); ++oh) {
119+
// Compute height weights for this output row
120+
int64_t h_indices[4];
121+
float h_weights[4];
122+
int64_t h_num_contributors;
123+
compute_aa_weights_for_pixel<float>(
124+
oh,
125+
scale_h,
126+
in.size(2),
127+
h_indices,
128+
h_weights,
129+
&h_num_contributors);
130+
131+
for (int64_t ow = 0; ow < out.size(3); ++ow) {
132+
// Compute width weights for this output column
133+
int64_t w_indices[4];
134+
float w_weights[4];
135+
int64_t w_num_contributors;
136+
compute_aa_weights_for_pixel<float>(
137+
ow,
138+
scale_w,
139+
in.size(3),
140+
w_indices,
141+
w_weights,
142+
&w_num_contributors);
143+
144+
CTYPE value = 0;
145+
146+
// Apply anti-aliased interpolation
147+
for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
148+
int64_t ih = h_indices[ih_idx];
149+
float h_weight = h_weights[ih_idx];
150+
151+
for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
152+
int64_t iw = w_indices[iw_idx];
153+
float w_weight = w_weights[iw_idx];
154+
155+
value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
156+
}
157+
}
158+
159+
out_plane[oh * out.size(3) + ow] = value;
160+
}
161+
}
162+
}
163+
}
164+
} else {
165+
// NHWC layout
166+
for (int64_t n = 0; n < out.size(0); ++n) {
167+
const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
168+
auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);
169+
170+
for (int64_t oh = 0; oh < out.size(2); ++oh) {
171+
// Compute height weights for this output row
172+
int64_t h_indices[4];
173+
float h_weights[4];
174+
int64_t h_num_contributors;
175+
compute_aa_weights_for_pixel<float>(
176+
oh, scale_h, in.size(2), h_indices, h_weights, &h_num_contributors);
177+
178+
for (int64_t ow = 0; ow < out.size(3); ++ow) {
179+
// Compute width weights for this output column
180+
int64_t w_indices[4];
181+
float w_weights[4];
182+
int64_t w_num_contributors;
183+
compute_aa_weights_for_pixel<float>(
184+
ow,
185+
scale_w,
186+
in.size(3),
187+
w_indices,
188+
w_weights,
189+
&w_num_contributors);
190+
191+
for (int64_t c = 0; c < out.size(1); ++c) {
192+
CTYPE value = 0;
193+
194+
// Apply anti-aliased interpolation
195+
for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
196+
int64_t ih = h_indices[ih_idx];
197+
float h_weight = h_weights[ih_idx];
198+
199+
for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
200+
int64_t iw = w_indices[iw_idx];
201+
float w_weight = w_weights[iw_idx];
202+
203+
value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
204+
h_weight * w_weight;
205+
}
206+
}
207+
208+
out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
209+
}
210+
}
211+
}
212+
}
213+
}
214+
}
215+
216+
} // namespace
217+
218+
// Check function for anti-aliased bilinear upsampling
219+
bool check_upsample_bilinear2d_aa_args(
220+
const Tensor& in,
221+
const executorch::aten::OptionalArrayRef<int64_t>& output_size,
222+
const bool align_corners,
223+
const executorch::aten::OptionalArrayRef<double>& scale_factors,
224+
Tensor& out) {
225+
// Use the same checks as regular bilinear upsampling
226+
return check_upsample_bilinear2d_args(
227+
in, output_size, align_corners, scale_factors, out);
228+
}
229+
230+
// Main entry point for anti-aliased bilinear upsampling
231+
Tensor& _upsample_bilinear2d_aa_out(
232+
KernelRuntimeContext& ctx,
233+
const Tensor& in,
234+
const executorch::aten::ArrayRef<int64_t> output_size,
235+
bool align_corners,
236+
const std::optional<double> scale_h,
237+
const std::optional<double> scale_w,
238+
Tensor& out) {
239+
// Preconditions (checked in check_..._args):
240+
// In and out tensors have same dtype.
241+
// In and out tensors are rank 4 and have same dim[0] and dim[1].
242+
// In and out tensors are NHWC or NCHW dim order.
243+
244+
// Custom validation for our specific interface (ArrayRef + optional
245+
// individual scales)
246+
ET_KERNEL_CHECK(ctx, in.dim() == 4, InvalidArgument, out);
247+
ET_KERNEL_CHECK(ctx, out.dim() == 4, InvalidArgument, out);
248+
ET_KERNEL_CHECK(
249+
ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
250+
ET_KERNEL_CHECK(ctx, output_size.size() == 2, InvalidArgument, out);
251+
ET_KERNEL_CHECK(
252+
ctx, output_size[0] > 0 && output_size[1] > 0, InvalidArgument, out);
253+
254+
// Ensure output tensor has correct dimensions
255+
ET_KERNEL_CHECK(
256+
ctx, out.size(0) == in.size(0), InvalidArgument, out); // batch
257+
ET_KERNEL_CHECK(
258+
ctx, out.size(1) == in.size(1), InvalidArgument, out); // channels
259+
ET_KERNEL_CHECK(
260+
ctx, out.size(2) == output_size[0], InvalidArgument, out); // height
261+
ET_KERNEL_CHECK(
262+
ctx, out.size(3) == output_size[1], InvalidArgument, out); // width
263+
264+
// Compute final scales - use provided scales if available, otherwise compute
265+
// from sizes
266+
double final_scale_h, final_scale_w;
267+
if (scale_h.has_value() && scale_w.has_value()) {
268+
final_scale_h = scale_h.value();
269+
final_scale_w = scale_w.value();
270+
} else {
271+
// Compute scales from input/output sizes
272+
final_scale_h =
273+
static_cast<double>(output_size[0]) / static_cast<double>(in.size(2));
274+
final_scale_w =
275+
static_cast<double>(output_size[1]) / static_cast<double>(in.size(3));
276+
}
277+
278+
const auto kernel_scale_h = area_pixel_compute_scale<double>(
279+
in.sizes()[2], out.sizes()[2], align_corners, final_scale_h);
280+
const auto kernel_scale_w = area_pixel_compute_scale<double>(
281+
in.sizes()[3], out.sizes()[3], align_corners, final_scale_w);
282+
283+
ET_SWITCH_REALHBF16_TYPES(
284+
in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
285+
upsample_bilinear2d_aa_kernel_impl<CTYPE>(
286+
ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
287+
});
288+
289+
return out;
290+
}
291+
292+
} // namespace native
293+
} // namespace executor
294+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,11 @@
965965
- arg_meta: null
966966
kernel_name: torch::executor::upsample_bilinear2d_vec_out
967967

968+
- op: _upsample_bilinear2d_aa.out
969+
kernels:
970+
- arg_meta: null
971+
kernel_name: torch::executor::_upsample_bilinear2d_aa_out
972+
968973
- op: upsample_nearest2d.vec_out
969974
kernels:
970975
- arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ runtime.cxx_library(
2020
deps = [
2121
"//executorch/extension/aten_util:aten_bridge",
2222
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
23+
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
2324
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",
2425
"//executorch/runtime/core/exec_aten:lib",
2526
],

0 commit comments

Comments
 (0)