Commit 0f5a109

mergennachin authored and facebook-github-bot committed
Add aten::_upsample_bilinear2d_aa.out (#13458)
Summary:
Trying to resolve #7031

Vibe-coded using the existing non-anti-aliased version in ET and the ATen implementation in PyTorch as references, along with the reference unit tests in PyTorch core.

Test Plan:
1. Run https://gist.github.com/mergennachin/9b02aee4feb5acc83e71d8f902f5cca1 and then call `./cmake-out/executor_runner minicpmv_preprocessor.pte`
2. https://gist.github.com/mergennachin/a24e4509804de99caf906c9b79ea45fc

Reviewed By: manuelcandales

Differential Revision: D80343748

Pulled By: mergennachin
1 parent bc5186c commit 0f5a109
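For context, the new kernel can be exercised end to end through the standard ExecuTorch export flow. The sketch below is illustrative only: it is not the gist referenced in the test plan, the model and file names are made up, and it assumes that bilinear interpolation with antialias=True lowers to _upsample_bilinear2d_aa.out when the program runs against the portable kernel library.

# Minimal sketch (assumptions noted above; not part of this commit).
import torch
from executorch.exir import to_edge


class ResizeAA(torch.nn.Module):
    def forward(self, x):
        # Anti-aliased bilinear downscale, the case the new kernel targets.
        return torch.nn.functional.interpolate(
            x, size=(224, 224), mode="bilinear",
            align_corners=False, antialias=True,
        )


example_inputs = (torch.randn(1, 3, 448, 448),)
program = to_edge(torch.export.export(ResizeAA(), example_inputs)).to_executorch()
with open("resize_aa.pte", "wb") as f:
    f.write(program.buffer)

The resulting resize_aa.pte can then be run with ./cmake-out/executor_runner, analogous to step 1 of the test plan.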

File tree

11 files changed: 1298 additions & 0 deletions

Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using executorch::aten::ArrayRef;
using executorch::aten::SizesType;
using std::optional;

namespace {

// Anti-aliasing filter matching PyTorch's implementation exactly
template <typename T>
inline T bilinear_aa_filter(T x) {
  x = std::abs(x);
  return (x < static_cast<T>(1.0)) ? (static_cast<T>(1.0) - x)
                                   : static_cast<T>(0.0);
}

// Compute anti-aliasing weights exactly matching PyTorch's algorithm
template <typename T>
void compute_aa_weights_for_pixel(
    int64_t output_idx,
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    int64_t* indices,
    T* weights,
    int64_t* num_contributors) {

  const T scale = area_pixel_compute_scale<T>(
      input_size, output_size, align_corners, optional<double>());

  // PyTorch's center calculation for anti-aliasing
  // Always uses scale * (i + 0.5) for anti-aliasing, regardless of align_corners
  const T center = scale * (output_idx + static_cast<T>(0.5));

  // PyTorch's support calculation for bilinear anti-aliasing
  // interp_size = 2 for bilinear, so base support = 1.0
  const T support = (scale >= static_cast<T>(1.0))
      ? (static_cast<T>(1.0) * scale)
      : static_cast<T>(1.0);

  // PyTorch's exact range calculation
  const int64_t xmin = std::max(
      static_cast<int64_t>(center - support + static_cast<T>(0.5)),
      static_cast<int64_t>(0));
  const int64_t xmax = std::min(
      static_cast<int64_t>(center + support + static_cast<T>(0.5)), input_size);

  *num_contributors = std::min(xmax - xmin, static_cast<int64_t>(4));

  // PyTorch's weight computation
  T total_weight = static_cast<T>(0.0);
  const T invscale = (scale >= static_cast<T>(1.0))
      ? (static_cast<T>(1.0) / scale)
      : static_cast<T>(1.0);

  for (int64_t j = 0; j < *num_contributors; ++j) {
    int64_t x = xmin + j;
    // PyTorch's exact weight formula: (j + xmin - center + 0.5) * invscale
    T arg = (static_cast<T>(j) + static_cast<T>(xmin) - center +
             static_cast<T>(0.5)) * invscale;
    T weight = bilinear_aa_filter<T>(arg);
    indices[j] = x;
    weights[j] = weight;
    total_weight += weight;
  }

  // Normalize weights to sum to 1 (PyTorch does this)
  if (total_weight > static_cast<T>(0.0)) {
    for (int64_t j = 0; j < *num_contributors; ++j) {
      weights[j] /= total_weight;
    }
  }

  // Clear unused weight slots
  for (int64_t j = *num_contributors; j < 4; ++j) {
    weights[j] = static_cast<T>(0.0);
  }
}

template <typename CTYPE>
void upsample_bilinear2d_aa_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  const bool is_nchw =
      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());

  if (is_nchw) {
    // NCHW layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      for (int64_t c = 0; c < out.size(1); ++c) {
        const auto in_plane =
            in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
        auto out_plane =
            out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);

        for (int64_t oh = 0; oh < out.size(2); ++oh) {
          // Compute height weights for this output row
          int64_t h_indices[4];
          float h_weights[4];
          int64_t h_num_contributors;
          compute_aa_weights_for_pixel<float>(
              oh,
              in.size(2),
              out.size(2),
              align_corners,
              h_indices,
              h_weights,
              &h_num_contributors);

          for (int64_t ow = 0; ow < out.size(3); ++ow) {
            // Compute width weights for this output column
            int64_t w_indices[4];
            float w_weights[4];
            int64_t w_num_contributors;
            compute_aa_weights_for_pixel<float>(
                ow,
                in.size(3),
                out.size(3),
                align_corners,
                w_indices,
                w_weights,
                &w_num_contributors);

            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
              }
            }

            out_plane[oh * out.size(3) + ow] = value;
          }
        }
      }
    }
  } else {
    // NHWC layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
      auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);

      for (int64_t oh = 0; oh < out.size(2); ++oh) {
        // Compute height weights for this output row
        int64_t h_indices[4];
        float h_weights[4];
        int64_t h_num_contributors;
        compute_aa_weights_for_pixel<float>(
            oh,
            in.size(2),
            out.size(2),
            align_corners,
            h_indices,
            h_weights,
            &h_num_contributors);

        for (int64_t ow = 0; ow < out.size(3); ++ow) {
          // Compute width weights for this output column
          int64_t w_indices[4];
          float w_weights[4];
          int64_t w_num_contributors;
          compute_aa_weights_for_pixel<float>(
              ow,
              in.size(3),
              out.size(3),
              align_corners,
              w_indices,
              w_weights,
              &w_num_contributors);

          for (int64_t c = 0; c < out.size(1); ++c) {
            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
                    h_weight * w_weight;
              }
            }

            out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
          }
        }
      }
    }
  }
}

} // namespace

// Check function for anti-aliased bilinear upsampling
bool check_upsample_bilinear2d_aa_args(
    const Tensor& in,
    const executorch::aten::OptionalArrayRef<int64_t>& output_size,
    const bool align_corners,
    const executorch::aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  // Use the same checks as regular bilinear upsampling
  return check_upsample_bilinear2d_args(
      in, output_size, align_corners, scale_factors, out);
}

// Main entry point for anti-aliased bilinear upsampling
Tensor& _upsample_bilinear2d_aa_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const executorch::aten::ArrayRef<int64_t> output_size,
    bool align_corners,
    const std::optional<double> scale_h,
    const std::optional<double> scale_w,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  // In and out tensors have same dtype.
  // In and out tensors are rank 4 and have same dim[0] and dim[1].
  // In and out tensors are NHWC or NCHW dim order.

  // Custom validation for our specific interface (ArrayRef + optional
  // individual scales)
  ET_KERNEL_CHECK(ctx, in.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, out.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, output_size.size() == 2, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, output_size[0] > 0 && output_size[1] > 0, InvalidArgument, out);

  // Ensure output tensor has correct dimensions
  ET_KERNEL_CHECK(
      ctx, out.size(0) == in.size(0), InvalidArgument, out); // batch
  ET_KERNEL_CHECK(
      ctx, out.size(1) == in.size(1), InvalidArgument, out); // channels
  ET_KERNEL_CHECK(
      ctx, out.size(2) == output_size[0], InvalidArgument, out); // height
  ET_KERNEL_CHECK(
      ctx, out.size(3) == output_size[1], InvalidArgument, out); // width

  // Compute final scales - use provided scales if available, otherwise compute
  // from sizes
  double final_scale_h, final_scale_w;
  if (scale_h.has_value() && scale_w.has_value()) {
    final_scale_h = scale_h.value();
    final_scale_w = scale_w.value();
  } else {
    // Compute scales from input/output sizes
    final_scale_h =
        static_cast<double>(output_size[0]) / static_cast<double>(in.size(2));
    final_scale_w =
        static_cast<double>(output_size[1]) / static_cast<double>(in.size(3));
  }

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, final_scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, final_scale_w);

  ET_SWITCH_REALHBF16_TYPES(
      in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
        upsample_bilinear2d_aa_kernel_impl<CTYPE>(
            ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
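To make the per-pixel weight computation above concrete, here is a small pure-Python sketch of the same algorithm. It is an illustration only, not part of the commit; it assumes no explicit scale factor is passed, so area_pixel_compute_scale reduces to input_size / output_size when align_corners is false, and it keeps the kernel's cap of four contributing input pixels.

# Hypothetical reference mirroring compute_aa_weights_for_pixel (illustration).
def bilinear_aa_weights(output_idx, input_size, output_size, align_corners=False):
    if align_corners:
        scale = (input_size - 1) / (output_size - 1) if output_size > 1 else 0.0
    else:
        scale = input_size / output_size  # no explicit scale factor supplied
    center = scale * (output_idx + 0.5)
    support = scale if scale >= 1.0 else 1.0   # interp_size = 2 -> base support 1.0
    xmin = max(int(center - support + 0.5), 0)
    xmax = min(int(center + support + 0.5), input_size)
    num = min(xmax - xmin, 4)                  # the kernel caps contributors at 4
    invscale = 1.0 / scale if scale >= 1.0 else 1.0
    weights = [max(0.0, 1.0 - abs((j + xmin - center + 0.5) * invscale))
               for j in range(num)]            # triangle (bilinear) filter
    total = sum(weights)
    if total > 0.0:
        weights = [w / total for w in weights]
    return list(range(xmin, xmin + num)), weights


# Example: downscaling a row of 8 pixels to 4, output pixel 1 reads input
# pixels 1..4 with normalized weights [0.125, 0.375, 0.375, 0.125].
print(bilinear_aa_weights(1, input_size=8, output_size=4))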

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -965,6 +965,11 @@
     - arg_meta: null
       kernel_name: torch::executor::upsample_bilinear2d_vec_out
 
+- op: _upsample_bilinear2d_aa.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_upsample_bilinear2d_aa_out
+
 - op: upsample_nearest2d.vec_out
   kernels:
     - arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ runtime.cxx_library(
     deps = [
         "//executorch/extension/aten_util:aten_bridge",
         "//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
+        "//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
         "//executorch/kernels/portable/cpu:op_upsample_nearest2d",
         "//executorch/runtime/core/exec_aten:lib",
     ],
