/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using executorch::aten::ArrayRef;
using executorch::aten::SizesType;
using std::optional;

namespace {

// Anti-aliasing filter for bilinear interpolation
// Adapted from PyTorch's implementation
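//
// The filter itself is the triangle (tent) kernel f(x) = max(0, 1 - |x|).
// When downsampling, the argument passed in is stretched by the scale factor,
// so the effective support widens to roughly `scale` input pixels; that wider
// averaging window is what suppresses aliasing.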
template <typename T>
inline T bilinear_aa_filter(T x) {
  x = std::abs(x);
  if (x < 1.0) {
    return 1.0 - x;
  }
  return 0.0;
}

// Compute weights and indices for a single output pixel with anti-aliasing
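//
// Maps the output index back to a fractional source coordinate, gathers the
// neighboring input pixels that fall inside the filter support (capped at 4
// taps), evaluates the stretched triangle filter at each of them, and
// normalizes the weights so they sum to 1.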
template <typename T>
void compute_aa_weights_for_pixel(
    int64_t output_idx,
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    int64_t* indices,
    T* weights,
    int64_t* num_contributors) {
  const T scale = area_pixel_compute_scale<T>(
      input_size, output_size, align_corners, optional<double>());

  const T center = area_pixel_compute_source_index<T>(
      scale, output_idx, align_corners, /*cubic=*/false);

  // Support is the filter radius - for downsampling, we need a larger filter
  const T support = (scale >= 1.0) ? scale : 1.0;

  // Find the range of input pixels that contribute
  const int64_t xmin = std::max(
      static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
  const int64_t xmax =
      std::min(static_cast<int64_t>(center + support + 0.5), input_size);

  T total_weight = 0.0;
  *num_contributors = std::min(xmax - xmin, int64_t(4));
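  // Note: contributions are capped at 4 taps per output pixel, so for
  // downsampling factors above roughly 4x the outermost filter taps are
  // dropped (a simplification relative to PyTorch's variable-width filter).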

  // Compute weights for contributing pixels
  for (int64_t j = 0; j < *num_contributors; ++j) {
    int64_t x = xmin + j;
    T weight = bilinear_aa_filter<T>(
        (x - center + 0.5) / (scale >= 1.0 ? scale : 1.0));
    indices[j] = x;
    weights[j] = weight;
    total_weight += weight;
  }

  // Normalize weights
  if (total_weight > 0) {
    for (int64_t j = 0; j < *num_contributors; ++j) {
      weights[j] /= total_weight;
    }
  }
}

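// Separable anti-aliased bilinear interpolation over a rank-4 tensor. Height
// and width weights are computed independently for each output pixel, and the
// output is the weighted sum
//   out[oh][ow] = sum_ih sum_iw h_weight[ih] * w_weight[iw] * in[ih][iw],
// evaluated per (n, c) plane for NCHW input and per pixel across channels for
// NHWC (channels-last) input.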
template <typename CTYPE>
void upsample_bilinear2d_aa_kernel_impl(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  const bool is_nchw =
      is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size());

  if (is_nchw) {
    // NCHW layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      for (int64_t c = 0; c < out.size(1); ++c) {
        const auto in_plane =
            in_data + (n * in.size(1) + c) * in.size(2) * in.size(3);
        auto out_plane =
            out_data + (n * out.size(1) + c) * out.size(2) * out.size(3);

        for (int64_t oh = 0; oh < out.size(2); ++oh) {
          // Compute height weights for this output row
          int64_t h_indices[4];
          float h_weights[4];
          int64_t h_num_contributors;
          compute_aa_weights_for_pixel<float>(
              oh,
              in.size(2),
              out.size(2),
              align_corners,
              h_indices,
              h_weights,
              &h_num_contributors);

          for (int64_t ow = 0; ow < out.size(3); ++ow) {
            // Compute width weights for this output column
            int64_t w_indices[4];
            float w_weights[4];
            int64_t w_num_contributors;
            compute_aa_weights_for_pixel<float>(
                ow,
                in.size(3),
                out.size(3),
                align_corners,
                w_indices,
                w_weights,
                &w_num_contributors);

            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_plane[ih * in.size(3) + iw] * h_weight * w_weight;
              }
            }

            out_plane[oh * out.size(3) + ow] = value;
          }
        }
      }
    }
  } else {
    // NHWC layout
    for (int64_t n = 0; n < out.size(0); ++n) {
      const auto in_batch = in_data + n * in.size(1) * in.size(2) * in.size(3);
      auto out_batch = out_data + n * out.size(1) * out.size(2) * out.size(3);

      for (int64_t oh = 0; oh < out.size(2); ++oh) {
        // Compute height weights for this output row
        int64_t h_indices[4];
        float h_weights[4];
        int64_t h_num_contributors;
        compute_aa_weights_for_pixel<float>(
            oh,
            in.size(2),
            out.size(2),
            align_corners,
            h_indices,
            h_weights,
            &h_num_contributors);

        for (int64_t ow = 0; ow < out.size(3); ++ow) {
          // Compute width weights for this output column
          int64_t w_indices[4];
          float w_weights[4];
          int64_t w_num_contributors;
          compute_aa_weights_for_pixel<float>(
              ow,
              in.size(3),
              out.size(3),
              align_corners,
              w_indices,
              w_weights,
              &w_num_contributors);

          for (int64_t c = 0; c < out.size(1); ++c) {
            CTYPE value = 0;

            // Apply anti-aliased interpolation
            for (int64_t ih_idx = 0; ih_idx < h_num_contributors; ++ih_idx) {
              int64_t ih = h_indices[ih_idx];
              float h_weight = h_weights[ih_idx];

              for (int64_t iw_idx = 0; iw_idx < w_num_contributors; ++iw_idx) {
                int64_t iw = w_indices[iw_idx];
                float w_weight = w_weights[iw_idx];

                value += in_batch[(ih * in.size(3) + iw) * in.size(1) + c] *
                    h_weight * w_weight;
              }
            }

            out_batch[(oh * out.size(3) + ow) * out.size(1) + c] = value;
          }
        }
      }
    }
  }
}

} // namespace

// Check function for anti-aliased bilinear upsampling
bool check_upsample_bilinear2d_aa_args(
    const Tensor& in,
    const executorch::aten::OptionalArrayRef<int64_t>& output_size,
    const bool align_corners,
    const executorch::aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  // Use the same checks as regular bilinear upsampling
  return check_upsample_bilinear2d_args(
      in, output_size, align_corners, scale_factors, out);
}

// Main entry point for anti-aliased bilinear upsampling
Tensor& _upsample_bilinear2d_aa_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const executorch::aten::OptionalArrayRef<int64_t> output_size,
    bool align_corners,
    const executorch::aten::OptionalArrayRef<double> scale_factors,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  // In and out tensors have same dtype.
  // In and out tensors are rank 4 and have same dim[0] and dim[1].
  // In and out tensors are NHWC or NCHW dim order.
  ET_KERNEL_CHECK(
      ctx,
      check_upsample_bilinear2d_aa_args(
          in, output_size, align_corners, scale_factors, out),
      InvalidArgument,
      out);

  double scale_h, scale_w;

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_upsample_2d(
          in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor");

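  // area_pixel_compute_scale mirrors the ATen helper: with align_corners the
  // scale comes from the (size - 1) ratio, otherwise from the size ratio (or
  // it is derived from the explicit scale factor when one was supplied).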
  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, scale_w);

  ET_SWITCH_REALHBF16_TYPES(
      in.scalar_type(), ctx, "_upsample_bilinear2d_aa.out", CTYPE, [&]() {
        upsample_bilinear2d_aa_kernel_impl<CTYPE>(
            ctx, in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch