Commit ef82d04

Add aten::_upsample_bilinear2d_aa.out (#13458)

Summary: Trying to resolve #7031. Vibe-coded using the existing non-anti-aliased version in ET and the ATen implementation in PyTorch as references, along with the reference unit tests in PyTorch core.

Test Plan:
1. Run https://gist.github.com/mergennachin/9b02aee4feb5acc83e71d8f902f5cca1 and then call `./cmake-out/executor_runner minicpmv_preprocessor.pte`
2. https://gist.github.com/mergennachin/a24e4509804de99caf906c9b79ea45fc

Reviewed By: manuelcandales

Differential Revision: D80343748

Pulled By: mergennachin
1 parent bc5186c commit ef82d04
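
For context on how this kernel gets exercised, here is a minimal export sketch in Python (this is not the Test Plan gist, whose contents are not reproduced here). `torch.nn.functional.interpolate` with `antialias=True` is the eager-mode entry point that lowers to `aten::_upsample_bilinear2d_aa`; the module, output size, and output file name below are illustrative assumptions, and the `to_edge`/`to_executorch` flow is the standard ExecuTorch export path.

    # Illustrative sketch, not the commit's test script: export a module whose
    # forward uses anti-aliased bilinear resizing so it hits the new portable kernel.
    import torch
    from executorch.exir import to_edge

    class Resize(torch.nn.Module):
        def forward(self, x):
            # antialias=True selects the anti-aliased bilinear path
            # (aten::_upsample_bilinear2d_aa) instead of plain bilinear.
            return torch.nn.functional.interpolate(
                x, size=(224, 224), mode="bilinear", align_corners=False, antialias=True
            )

    example_inputs = (torch.randn(1, 3, 448, 448),)
    exported = torch.export.export(Resize(), example_inputs)
    program = to_edge(exported).to_executorch()
    with open("resize_aa.pte", "wb") as f:  # hypothetical output name
        f.write(program.buffer)
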

File tree

11 files changed: +1286, -0 lines
Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using executorch::aten::ArrayRef;
using executorch::aten::SizesType;
using std::optional;

namespace {

// Anti-aliasing filter matching PyTorch's implementation exactly
template <typename T>
inline T bilinear_aa_filter(T x) {
  x = std::abs(x);
  return (x < static_cast<T>(1.0)) ? (static_cast<T>(1.0) - x)
                                   : static_cast<T>(0.0);
}

// Compute anti-aliasing weights exactly matching PyTorch's algorithm
template <typename T>
void compute_aa_weights_for_pixel(
    int64_t output_idx,
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    int64_t* indices,
    T* weights,
    int64_t* num_contributors) {
  const T scale = area_pixel_compute_scale<T>(
      input_size, output_size, align_corners, optional<double>());

  // PyTorch's center calculation for anti-aliasing
  // Always uses scale * (i + 0.5) for anti-aliasing, regardless of
  // align_corners
  const T center = scale * (output_idx + static_cast<T>(0.5));

  // PyTorch's support calculation for bilinear anti-aliasing
  // interp_size = 2 for bilinear, so base support = 1.0
  const T support = (scale >= static_cast<T>(1.0))
      ? (static_cast<T>(1.0) * scale)
      : static_cast<T>(1.0);

  // PyTorch's exact range calculation
  const int64_t xmin = std::max(
      static_cast<int64_t>(center - support + static_cast<T>(0.5)),
      static_cast<int64_t>(0));
  const int64_t xmax = std::min(
      static_cast<int64_t>(center + support + static_cast<T>(0.5)), input_size);

  *num_contributors = std::min(xmax - xmin, static_cast<int64_t>(4));

  // PyTorch's weight computation
  T total_weight = static_cast<T>(0.0);
  const T invscale = (scale >= static_cast<T>(1.0))
      ? (static_cast<T>(1.0) / scale)
      : static_cast<T>(1.0);

  for (int64_t j = 0; j < *num_contributors; ++j) {
    int64_t x = xmin + j;
    // PyTorch's exact weight formula: (j + xmin - center + 0.5) * invscale
    T arg = (static_cast<T>(j) + static_cast<T>(xmin) - center +
             static_cast<T>(0.5)) *
        invscale;
    T weight = bilinear_aa_filter<T>(arg);
    indices[j] = x;
    weights[j] = weight;
    total_weight += weight;
  }

  // Normalize weights to sum to 1 (PyTorch does this)
  if (total_weight > static_cast<T>(0.0)) {
    for (int64_t j = 0; j < *num_contributors; ++j) {
      weights[j] /= total_weight;
    }
  }

  // Clear unused weight slots
  for (int64_t j = *num_contributors; j < 4; ++j) {
    weights[j] = static_cast<T>(0.0);
  }
}

template <typename CTYPE>
void upsample_bilinear2d_aa_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  const bool is_nchw =
      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());

  if (is_nchw) {
    // NCHW layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      for (int64_t c = 0; c < out.size(1); ++c) {
        const auto in_plane =
            in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
        auto out_plane =
            out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);

        for (int64_t oh = 0; oh < out.size(2); ++oh) {
          // Compute height weights for this output row
          int64_t h_indices[4];
          float h_weights[4];
          int64_t h_num_contributors;
          compute_aa_weights_for_pixel<float>(
              oh,
              in.size(2),
              out.size(2),
              align_corners,
              h_indices,
              h_weights,
              &h_num_contributors);

          for (int64_t ow = 0; ow < out.size(3); ++ow) {
            // Compute width weights for this output column
            int64_t w_indices[4];
            float w_weights[4];
            int64_t w_num_contributors;
            compute_aa_weights_for_pixel<float>(
                ow,
                in.size(3),
                out.size(3),
                align_corners,
                w_indices,
                w_weights,
                &w_num_contributors);

            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
              }
            }

            out_plane[oh * out.size(3) + ow] = value;
          }
        }
      }
    }
  } else {
    // NHWC layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
      auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);

      for (int64_t oh = 0; oh < out.size(2); ++oh) {
        // Compute height weights for this output row
        int64_t h_indices[4];
        float h_weights[4];
        int64_t h_num_contributors;
        compute_aa_weights_for_pixel<float>(
            oh,
            in.size(2),
            out.size(2),
            align_corners,
            h_indices,
            h_weights,
            &h_num_contributors);

        for (int64_t ow = 0; ow < out.size(3); ++ow) {
          // Compute width weights for this output column
          int64_t w_indices[4];
          float w_weights[4];
          int64_t w_num_contributors;
          compute_aa_weights_for_pixel<float>(
              ow,
              in.size(3),
              out.size(3),
              align_corners,
              w_indices,
              w_weights,
              &w_num_contributors);

          for (int64_t c = 0; c < out.size(1); ++c) {
            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
                    h_weight * w_weight;
              }
            }

            out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
          }
        }
      }
    }
  }
}

} // namespace

// Check function for anti-aliased bilinear upsampling
bool check_upsample_bilinear2d_aa_args(
    const Tensor& in,
    const executorch::aten::OptionalArrayRef<int64_t>& output_size,
    const bool align_corners,
    const executorch::aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  // Use the same checks as regular bilinear upsampling
  return check_upsample_bilinear2d_args(
      in, output_size, align_corners, scale_factors, out);
}

// Main entry point for anti-aliased bilinear upsampling
Tensor& _upsample_bilinear2d_aa_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const executorch::aten::ArrayRef<int64_t> output_size,
    bool align_corners,
    const std::optional<double> scale_h,
    const std::optional<double> scale_w,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  // In and out tensors have same dtype.
  // In and out tensors are rank 4 and have same dim[0] and dim[1].
  // In and out tensors are NHWC or NCHW dim order.

  // Custom validation for our specific interface (ArrayRef + optional
  // individual scales)
  ET_KERNEL_CHECK(ctx, in.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, out.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, output_size.size() == 2, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, output_size[0] > 0 && output_size[1] > 0, InvalidArgument, out);

  // Ensure output tensor has correct dimensions
  ET_KERNEL_CHECK(
      ctx, out.size(0) == in.size(0), InvalidArgument, out); // batch
  ET_KERNEL_CHECK(
      ctx, out.size(1) == in.size(1), InvalidArgument, out); // channels
  ET_KERNEL_CHECK(
      ctx, out.size(2) == output_size[0], InvalidArgument, out); // height
  ET_KERNEL_CHECK(
      ctx, out.size(3) == output_size[1], InvalidArgument, out); // width

  // Compute final scales - use provided scales if available, otherwise compute
  // from sizes
  double final_scale_h, final_scale_w;
  if (scale_h.has_value() && scale_w.has_value()) {
    final_scale_h = scale_h.value();
    final_scale_w = scale_w.value();
  } else {
    // Compute scales from input/output sizes
    final_scale_h =
        static_cast<double>(output_size[0]) / static_cast<double>(in.size(2));
    final_scale_w =
        static_cast<double>(output_size[1]) / static_cast<double>(in.size(3));
  }

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, final_scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, final_scale_w);

  ET_SWITCH_REALHBF16_TYPES(
      in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
        upsample_bilinear2d_aa_kernel_impl<CTYPE>(
            ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
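
To make the weight computation above concrete, here is a small standalone sketch in Python of the same per-pixel algorithm for a 1-D resize, assuming align_corners=False and no explicit scale (so the scale reduces to input_size / output_size). The helper name and the 8-to-4 downscale example are illustrative only and are not part of the commit.

    # Illustrative re-expression of compute_aa_weights_for_pixel for one output pixel.
    def aa_weights_for_pixel(output_idx, input_size, output_size):
        scale = input_size / output_size           # area_pixel scale, align_corners=False
        center = scale * (output_idx + 0.5)        # source-space center of the output pixel
        support = scale if scale >= 1.0 else 1.0   # bilinear base support is 1.0
        xmin = max(int(center - support + 0.5), 0)
        xmax = min(int(center + support + 0.5), input_size)
        n = min(xmax - xmin, 4)                    # at most 4 contributors are kept
        invscale = 1.0 / scale if scale >= 1.0 else 1.0
        # Triangle filter: (1 - |x|) for |x| < 1, else 0.
        weights = [max(0.0, 1.0 - abs((j + xmin - center + 0.5) * invscale)) for j in range(n)]
        total = sum(weights)
        weights = [w / total for w in weights] if total > 0 else weights
        return list(range(xmin, xmin + n)), weights

    # Downscaling 8 -> 4: output pixel 0 blends input pixels 0..2.
    print(aa_weights_for_pixel(0, 8, 4))  # ([0, 1, 2], [~0.429, ~0.429, ~0.143])
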

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -965,6 +965,11 @@
     - arg_meta: null
       kernel_name: torch::executor::upsample_bilinear2d_vec_out
 
+- op: _upsample_bilinear2d_aa.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_upsample_bilinear2d_aa_out
+
 - op: upsample_nearest2d.vec_out
   kernels:
     - arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ runtime.cxx_library(
     deps = [
         "//executorch/extension/aten_util:aten_bridge",
         "//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
+        "//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
         "//executorch/kernels/portable/cpu:op_upsample_nearest2d",
         "//executorch/runtime/core/exec_aten:lib",
     ],

0 commit comments
