
Commit 8959bb5

mergennachin authored and facebook-github-bot committed

Add aten::_upsample_bilinear2d_aa.out (#13458)
Summary:
Trying to resolve #7031.

Vibe-coded using the existing non-alias version in ET and the ATen implementation in PyTorch as reference, along with reference unit tests in PyTorch core.

Test Plan:
1. Run https://gist.github.com/mergennachin/9b02aee4feb5acc83e71d8f902f5cca1 and then call `./cmake-out/executor_runner minicpmv_preprocessor.pte`
2. Run https://gist.github.com/mergennachin/a24e4509804de99caf906c9b79ea45fc

Reviewed By: manuelcandales

Differential Revision: D80343748

Pulled By: mergennachin
1 parent fff2090 commit 8959bb5

File tree

11 files changed: +1344 -0 lines changed
Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using executorch::aten::ArrayRef;
using executorch::aten::SizesType;
using std::optional;

namespace {

// Anti-aliasing filter matching PyTorch's implementation exactly
template <typename T>
inline T bilinear_aa_filter(T x) {
  x = std::abs(x);
  return (x < static_cast<T>(1.0)) ? (static_cast<T>(1.0) - x)
                                   : static_cast<T>(0.0);
}

// Compute anti-aliasing weights exactly matching PyTorch's algorithm
template <typename T>
void compute_aa_weights_for_pixel(
    int64_t output_idx,
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    int64_t* indices,
    T* weights,
    int64_t* num_contributors) {
  // Special handling for align_corners=True: corner pixels map exactly to
  // input corners
  if (align_corners && output_size > 1) {
    if (output_idx == 0) {
      // First output pixel maps exactly to first input pixel
      indices[0] = 0;
      weights[0] = static_cast<T>(1.0);
      *num_contributors = 1;
      for (int64_t j = 1; j < 4; ++j)
        weights[j] = static_cast<T>(0.0);
      return;
    } else if (output_idx == output_size - 1) {
      // Last output pixel maps exactly to last input pixel
      indices[0] = input_size - 1;
      weights[0] = static_cast<T>(1.0);
      *num_contributors = 1;
      for (int64_t j = 1; j < 4; ++j)
        weights[j] = static_cast<T>(0.0);
      return;
    }
  }

  const T scale = area_pixel_compute_scale<T>(
      input_size, output_size, align_corners, optional<double>());

  // PyTorch's exact center calculation for anti-aliasing
  // Always uses scale * (i + 0.5) regardless of align_corners for AA
  const T center = scale * (output_idx + static_cast<T>(0.5));

  // PyTorch's support calculation for bilinear anti-aliasing
  // interp_size = 2 for bilinear, so base support = 1.0
  const T support =
      (scale >= static_cast<T>(1.0)) ? scale : static_cast<T>(1.0);

  // PyTorch's exact range calculation
  const int64_t xmin = std::max(
      static_cast<int64_t>(center - support + static_cast<T>(0.5)),
      static_cast<int64_t>(0));
  const int64_t xmax = std::min(
      static_cast<int64_t>(center + support + static_cast<T>(0.5)), input_size);

  *num_contributors = std::min(xmax - xmin, static_cast<int64_t>(4));

  // PyTorch's weight computation
  T total_weight = static_cast<T>(0.0);
  const T invscale = (scale >= static_cast<T>(1.0))
      ? (static_cast<T>(1.0) / scale)
      : static_cast<T>(1.0);

  for (int64_t j = 0; j < *num_contributors; ++j) {
    int64_t x = xmin + j;
    // PyTorch's exact weight formula: (j + xmin - center + 0.5) * invscale
    T arg = (static_cast<T>(j) + static_cast<T>(xmin) - center +
             static_cast<T>(0.5)) *
        invscale;
    T weight = bilinear_aa_filter<T>(arg);
    indices[j] = x;
    weights[j] = weight;
    total_weight += weight;
  }

  // Normalize weights to sum to 1 (PyTorch does this)
  if (total_weight > static_cast<T>(0.0)) {
    for (int64_t j = 0; j < *num_contributors; ++j) {
      weights[j] /= total_weight;
    }
  }

  // Clear unused weight slots
  for (int64_t j = *num_contributors; j < 4; ++j) {
    weights[j] = static_cast<T>(0.0);
  }
}
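
// Illustrative example (editorial note, not in the upstream sources): with
// align_corners=false and no explicit scale, a 4 -> 2 resize gives
// scale = 2.0, so for output_idx = 0: center = 1.0, support = 2.0,
// xmin = 0, xmax = 3, and the raw triangle weights are
// filter(-0.25) = 0.75, filter(0.25) = 0.75, filter(0.75) = 0.25,
// which normalize to approximately 0.4286, 0.4286, 0.1429.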

template <typename CTYPE>
void upsample_bilinear2d_aa_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  const bool is_nchw =
      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());

  if (is_nchw) {
    // NCHW layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      for (int64_t c = 0; c < out.size(1); ++c) {
        const auto in_plane =
            in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
        auto out_plane =
            out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);

        for (int64_t oh = 0; oh < out.size(2); ++oh) {
          // Compute height weights for this output row
          int64_t h_indices[4];
          float h_weights[4];
          int64_t h_num_contributors;
          compute_aa_weights_for_pixel<float>(
              oh,
              in.size(2),
              out.size(2),
              align_corners,
              h_indices,
              h_weights,
              &h_num_contributors);

          for (int64_t ow = 0; ow < out.size(3); ++ow) {
            // Compute width weights for this output column
            int64_t w_indices[4];
            float w_weights[4];
            int64_t w_num_contributors;
            compute_aa_weights_for_pixel<float>(
                ow,
                in.size(3),
                out.size(3),
                align_corners,
                w_indices,
                w_weights,
                &w_num_contributors);

            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
              }
            }

            out_plane[oh * out.size(3) + ow] = value;
          }
        }
      }
    }
  } else {
    // NHWC layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
      auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);

      for (int64_t oh = 0; oh < out.size(2); ++oh) {
        // Compute height weights for this output row
        int64_t h_indices[4];
        float h_weights[4];
        int64_t h_num_contributors;
        compute_aa_weights_for_pixel<float>(
            oh,
            in.size(2),
            out.size(2),
            align_corners,
            h_indices,
            h_weights,
            &h_num_contributors);

        for (int64_t ow = 0; ow < out.size(3); ++ow) {
          // Compute width weights for this output column
          int64_t w_indices[4];
          float w_weights[4];
          int64_t w_num_contributors;
          compute_aa_weights_for_pixel<float>(
              ow,
              in.size(3),
              out.size(3),
              align_corners,
              w_indices,
              w_weights,
              &w_num_contributors);

          for (int64_t c = 0; c < out.size(1); ++c) {
            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
                    h_weight * w_weight;
              }
            }

            out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
          }
        }
      }
    }
  }
}

} // namespace

// Check function for anti-aliased bilinear upsampling
bool check_upsample_bilinear2d_aa_args(
    const Tensor& in,
    const executorch::aten::OptionalArrayRef<int64_t>& output_size,
    const bool align_corners,
    const executorch::aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  // Use the same checks as regular bilinear upsampling
  return check_upsample_bilinear2d_args(
      in, output_size, align_corners, scale_factors, out);
}

// Main entry point for anti-aliased bilinear upsampling
Tensor& _upsample_bilinear2d_aa_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const executorch::aten::ArrayRef<int64_t> output_size,
    bool align_corners,
    const std::optional<double> scale_h,
    const std::optional<double> scale_w,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  // In and out tensors have same dtype.
  // In and out tensors are rank 4 and have same dim[0] and dim[1].
  // In and out tensors are NHWC or NCHW dim order.

  // Custom validation for our specific interface (ArrayRef + optional
  // individual scales)
  ET_KERNEL_CHECK(ctx, in.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, out.dim() == 4, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, output_size.size() == 2, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, output_size[0] > 0 && output_size[1] > 0, InvalidArgument, out);

  // Ensure output tensor has correct dimensions
  ET_KERNEL_CHECK(
      ctx, out.size(0) == in.size(0), InvalidArgument, out); // batch
  ET_KERNEL_CHECK(
      ctx, out.size(1) == in.size(1), InvalidArgument, out); // channels
  ET_KERNEL_CHECK(
      ctx, out.size(2) == output_size[0], InvalidArgument, out); // height
  ET_KERNEL_CHECK(
      ctx, out.size(3) == output_size[1], InvalidArgument, out); // width

  // Compute final scales - use provided scales if available, otherwise compute
  // from sizes
  double final_scale_h, final_scale_w;
  if (scale_h.has_value() && scale_w.has_value()) {
    final_scale_h = scale_h.value();
    final_scale_w = scale_w.value();
  } else {
    // Compute scales from input/output sizes
    final_scale_h =
        static_cast<double>(output_size[0]) / static_cast<double>(in.size(2));
    final_scale_w =
        static_cast<double>(output_size[1]) / static_cast<double>(in.size(3));
  }

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, final_scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, final_scale_w);

  ET_SWITCH_REALHBF16_TYPES(
      in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
        upsample_bilinear2d_aa_kernel_impl<CTYPE>(
            ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
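
To sanity-check the weight math outside of ExecuTorch, here is a minimal standalone sketch (an editorial illustration, not part of this commit). It mirrors the triangle-filter computation of compute_aa_weights_for_pixel for the align_corners=false, no-explicit-scale case, where scale reduces to input_size / output_size, and prints each output pixel's contributing input range and normalized weights.

// Standalone illustration (hypothetical, not part of this commit): mirrors
// the triangle-filter weight math above with no ExecuTorch dependencies,
// assuming align_corners=false and no explicit scale factor, so that
// scale = input_size / output_size.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t input_size = 8;
  const int64_t output_size = 4;
  const double scale =
      static_cast<double>(input_size) / static_cast<double>(output_size);
  const double support = std::max(scale, 1.0);
  const double invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;

  for (int64_t i = 0; i < output_size; ++i) {
    // Same center / range computation as the kernel.
    const double center = scale * (i + 0.5);
    const int64_t xmin =
        std::max<int64_t>(static_cast<int64_t>(center - support + 0.5), 0);
    const int64_t xmax = std::min<int64_t>(
        static_cast<int64_t>(center + support + 0.5), input_size);
    const int64_t n = std::min<int64_t>(xmax - xmin, 4);

    // Triangle filter, then normalize so the taps sum to 1.
    double w[4] = {0.0, 0.0, 0.0, 0.0};
    double total = 0.0;
    for (int64_t j = 0; j < n; ++j) {
      const double arg = std::abs((j + xmin - center + 0.5) * invscale);
      w[j] = (arg < 1.0) ? (1.0 - arg) : 0.0;
      total += w[j];
    }

    std::printf(
        "out %lld <- in [%lld, %lld):",
        static_cast<long long>(i),
        static_cast<long long>(xmin),
        static_cast<long long>(xmax));
    for (int64_t j = 0; j < n; ++j) {
      std::printf(" %.4f", total > 0.0 ? w[j] / total : 0.0);
    }
    std::printf("\n");
  }
  return 0;
}

For this 8 -> 4 case the edge rows should print taps of roughly [0.4286, 0.4286, 0.1429] and the interior rows the familiar [0.125, 0.375, 0.375, 0.125], matching the worked example in the kernel comments above.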

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -965,6 +965,11 @@
     - arg_meta: null
       kernel_name: torch::executor::upsample_bilinear2d_vec_out
 
+- op: _upsample_bilinear2d_aa.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_upsample_bilinear2d_aa_out
+
 - op: upsample_nearest2d.vec_out
   kernels:
     - arg_meta: null
kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ runtime.cxx_library(
     deps = [
         "//executorch/extension/aten_util:aten_bridge",
         "//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
+        "//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
         "//executorch/kernels/portable/cpu:op_upsample_nearest2d",
         "//executorch/runtime/core/exec_aten:lib",
     ],