
Commit f46d433

GregoryComer authored and facebook-github-bot committed
Add portable upsample_bilinear2d kernel (#6923)
Summary: Add an upsample_bilinear2d kernel to the portable kernel library. This implementation reuses some of the inner logic from the ATen implementation (see Upsample.h and UpsampleKernel.cpp), but I have not ported the outer kernel structure, as it relies on TensorIterator and runtime allocation. It may be worth revisiting this in the future, either by pulling in more of the ATen implementation or by adding an optimized variant.

Differential Revision: D65756150

Pulled By: GregoryComer
1 parent 54feeef commit f46d433
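
For orientation: bilinear upsampling computes each output pixel as a lambda-weighted blend of the four input pixels nearest the back-projected output coordinate, which is exactly the inner loop of the kernel source below. A minimal standalone sketch of the blend step (illustrative only, not part of this commit):

```cpp
#include <cstdio>

// Blend four neighboring input values with per-axis lambda weights,
// mirroring the top/bottom blend in the kernel below.
float bilinear_blend(
    float top_left,
    float top_right,
    float bottom_left,
    float bottom_right,
    float weight_h,
    float inv_weight_h,
    float weight_w,
    float inv_weight_w) {
  const float top = top_left * weight_w + top_right * inv_weight_w;
  const float bottom = bottom_left * weight_w + bottom_right * inv_weight_w;
  return top * weight_h + bottom * inv_weight_h;
}

int main() {
  // At the midpoint of a 2x2 patch all four lambdas are 0.5, so the
  // result is the average of the four corners: (0 + 1 + 2 + 3) / 4.
  std::printf("%f\n", bilinear_blend(0.f, 1.f, 2.f, 3.f, .5f, .5f, .5f, .5f));
  return 0;
}
```

The per-axis weights come from compute_source_index_and_lambda, ported into upsample_util.h further down.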

File tree: 8 files changed, +943 −0
Lines changed: 133 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <algorithm>
#include <array>

namespace torch {
namespace executor {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl(
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  auto in_plane = in_data;
  for (auto n = 0; n < out.size(0); n++) {
    for (auto c = 0; c < out.size(1); c++) {
      for (auto h = 0; h < out.size(2); h++) {
        for (auto w = 0; w < out.size(3); w++) {
          // Compute source index.
          // See area_pixel_compute_source_index in
          // pytorch/aten/src/ATen/native/UpSample.h
          int64_t in_h1, in_h2, in_w1, in_w2;
          float weight_h, inv_weight_h, weight_w, inv_weight_w;

          compute_source_index_and_lambda(
              in_h1,
              in_h2,
              weight_h,
              inv_weight_h,
              scale_h,
              h,
              in.sizes()[2],
              out.sizes()[2],
              align_corners);

          compute_source_index_and_lambda(
              in_w1,
              in_w2,
              weight_w,
              inv_weight_w,
              scale_w,
              w,
              in.sizes()[3],
              out.sizes()[3],
              align_corners);

          const auto top_left =
              in_plane[in_h1 * in.strides()[2] + in_w1 * in.strides()[3]];
          const auto top_right =
              in_plane[in_h1 * in.strides()[2] + in_w2 * in.strides()[3]];
          const auto bottom_left =
              in_plane[in_h2 * in.strides()[2] + in_w1 * in.strides()[3]];
          const auto bottom_right =
              in_plane[in_h2 * in.strides()[2] + in_w2 * in.strides()[3]];

          const auto top = top_left * weight_w + top_right * inv_weight_w;
          const auto bottom =
              bottom_left * weight_w + bottom_right * inv_weight_w;
          const auto val = top * weight_h + bottom * inv_weight_h;

          *out_data = val;
          out_data++;
        }
      }

      in_plane += in.strides()[1];
    }
  }
}
} // namespace

Tensor& upsample_bilinear2d_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    bool align_corners,
    const optional<double> scale_h,
    const optional<double> scale_w,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  //  In and out tensors have same dtype.
  //  In and out tensors are rank 4 and have same dim[0] and dim[1].
  //  In and out tensors are default dim order (NCHW).
  ET_KERNEL_CHECK(
      ctx,
      check_upsample_bilinear2d_args(
          in, output_size, align_corners, scale_h, scale_w, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_upsample_2d(in, output_size, scale_h, scale_w, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor");

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, scale_w);

  ET_SWITCH_REAL_TYPES(
      in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
        upsample_bilinear2d_kernel_impl<CTYPE>(
            in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
```
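
A subtlety in the loop above: in_plane advances by strides()[1] once per channel and is never explicitly bumped per batch. This is sound because the preconditions guarantee default (contiguous NCHW) dim order, where walking size(1) channel planes lands exactly on the next batch. A standalone sketch of that invariant (assumes contiguous NCHW; not part of the commit):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Contiguous NCHW strides for an (N, 3, 4, 5) tensor.
  const int64_t C = 3, H = 4, W = 5;
  const int64_t strides[4] = {C * H * W, H * W, W, 1};
  // Advancing C channel planes covers exactly one batch stride.
  std::printf(
      "C * strides[1] = %lld, strides[0] = %lld\n",
      static_cast<long long>(C * strides[1]),
      static_cast<long long>(strides[0]));
  return 0;
}
```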

kernels/portable/cpu/util/targets.bzl

Lines changed: 11 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:advanced_index_util",
             "//executorch/kernels/portable/cpu/util:slice_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
+            "//executorch/kernels/portable/cpu/util:upsample_util",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
@@ -266,6 +267,16 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/..."],
     )

+    runtime.cxx_library(
+        name = "upsample_util",
+        srcs = ["upsample_util.cpp"],
+        exported_headers = ["upsample_util.h"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/..."],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in [True, False]:
         suffix = "_aten" if aten_mode else ""
```
kernels/portable/cpu/util/upsample_util.cpp

Lines changed: 95 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4);
  ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out));
  ET_LOG_AND_RETURN_IF_FALSE(
      (output_size.size() == 2 && !scale_h.has_value() &&
       !scale_w.has_value()) ||
      (output_size.size() == 0 && scale_h.has_value() && scale_w.has_value()));
  ET_LOG_AND_RETURN_IF_FALSE(!scale_h.has_value() || scale_h.value() > 0);
  ET_LOG_AND_RETURN_IF_FALSE(!scale_w.has_value() || scale_w.value() > 0);
  ET_LOG_AND_RETURN_IF_FALSE(output_size.size() < 1 || output_size[0] > 0);
  ET_LOG_AND_RETURN_IF_FALSE(output_size.size() < 2 || output_size[1] > 0);

  return true;
}

bool check_upsample_bilinear2d_args(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    ET_UNUSED const bool align_corners,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out) {
  return check_upsample_2d_common_args(in, output_size, scale_h, scale_w, out);
}

bool check_upsample_nearest2d_args(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out) {
  return check_upsample_2d_common_args(in, output_size, scale_h, scale_w, out);
}

Error resize_upsample_2d(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out) {
  // Either output_size or scale_factors are provided, not both. This
  // is checked in check_..._args.
  // Scales are transformed according to align_corners.
  std::array<Tensor::SizesType, kTensorDimensionLimit> target_size;

  const auto dim = in.dim();
  std::copy(in.sizes().cbegin(), in.sizes().cend(), target_size.begin());

  if (scale_h.has_value() && scale_w.has_value()) {
    target_size[dim - 2] =
        static_cast<Tensor::SizesType>(in.sizes()[dim - 2] * scale_h.value());
    target_size[dim - 1] =
        static_cast<Tensor::SizesType>(in.sizes()[dim - 1] * scale_w.value());
  } else if (output_size.size() == 2) {
    target_size[dim - 2] = output_size[0];
    target_size[dim - 1] = output_size[1];
  } else {
    ET_LOG(Error, "Invalid output_size or scale_factors");
    return Error::InvalidArgument;
  }

  ET_CHECK_OR_RETURN_ERROR(
      target_size[dim - 2] > 0 && target_size[dim - 1] > 0,
      InvalidArgument,
      "Upsampled output size must be non-empty, but was %ld x %ld.",
      static_cast<long>(target_size[dim - 2]),
      static_cast<long>(target_size[dim - 1]));

  return resize_tensor(out, {target_size.data(), static_cast<size_t>(dim)});
}

} // namespace executor
} // namespace torch
```
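
check_upsample_2d_common_args above enforces that callers pass exactly one of two modes: a 2-element output_size with neither scale set, or both scales with an empty output_size. A standalone sketch of the size math in resize_upsample_2d under each mode (values illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t in_h = 4, in_w = 6;
  // Mode 1: output_size = {8, 9}; scale_h/scale_w unset.
  const int64_t out_h_a = 8, out_w_a = 9;
  // Mode 2: scale_h = 2.0, scale_w = 1.5; output_size empty.
  // Target extents are the truncated products, as in resize_upsample_2d.
  const int64_t out_h_b = static_cast<int64_t>(in_h * 2.0); // 8
  const int64_t out_w_b = static_cast<int64_t>(in_w * 1.5); // 9
  std::printf(
      "mode 1: %lldx%lld, mode 2: %lldx%lld\n",
      static_cast<long long>(out_h_a),
      static_cast<long long>(out_w_a),
      static_cast<long long>(out_h_b),
      static_cast<long long>(out_w_b));
  return 0;
}
```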
kernels/portable/cpu/util/upsample_util.h

Lines changed: 139 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out);

bool check_upsample_bilinear2d_args(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const bool align_corners,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out);

bool check_upsample_nearest2d_args(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out);

Error resize_upsample_2d(
    const Tensor& in,
    const exec_aten::ArrayRef<int64_t> output_size,
    const exec_aten::optional<double> scale_h,
    const exec_aten::optional<double> scale_w,
    Tensor& out);

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t compute_scales_value(
    const exec_aten::optional<double> scale,
    int64_t input_size,
    int64_t output_size) {
  return scale.has_value() ? static_cast<scalar_t>(1.0 / scale.value())
                           : (static_cast<scalar_t>(input_size) / output_size);
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const exec_aten::optional<double> scale) {
  // see Note [area_pixel_compute_scale]
  if (align_corners) {
    if (output_size > 1) {
      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
    } else {
      return static_cast<scalar_t>(0);
    }
  } else {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
        static_cast<scalar_t>(0.5);
    return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
                                                          : src_idx;
  }
}

// Ported from aten/src/ATen/native/UpSample.h
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template <typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(
    const opmath_t& real_input_index,
    const int64_t& input_size,
    int64_t& input_index,
    scalar_t& lambda) {
  input_index =
      std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
  lambda = std::min(
      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
      static_cast<opmath_t>(1));
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
    int64_t& input_index0,
    int64_t& input_index1,
    scalar_t& lambda0,
    scalar_t& lambda1,
    opmath_t ratio,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    input_index0 = output_index;
    input_index1 = output_index;
    lambda0 = static_cast<scalar_t>(1);
    lambda1 = static_cast<scalar_t>(0);
  } else {
    const auto real_input_index = area_pixel_compute_source_index<opmath_t>(
        ratio, output_index, align_corners, /*cubic=*/false);
    guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
    int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
    input_index1 = input_index0 + offset;
    lambda0 = static_cast<scalar_t>(1.) - lambda1;
  }
}

} // namespace executor
} // namespace torch
```
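
To make the non-trivial branch of compute_source_index_and_lambda concrete, here is a standalone trace of the align_corners=false index math for a 4 -> 8 upscale along one axis, mirroring the helpers above (illustrative only, not part of the commit):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t input_size = 4, output_size = 8;
  // compute_scales_value with no explicit scale: input / output = 0.5.
  const float scale = static_cast<float>(input_size) / output_size;
  for (int64_t w = 0; w < output_size; ++w) {
    // area_pixel_compute_source_index, non-cubic, align_corners=false.
    float src = scale * (w + 0.5f) - 0.5f;
    src = std::max(src, 0.0f);
    // guard_index_and_lambda: clamp the index and the fractional weight.
    const int64_t w1 =
        std::min(static_cast<int64_t>(std::floor(src)), input_size - 1);
    const float lambda1 = std::min(std::max(src - w1, 0.0f), 1.0f);
    const int64_t w2 = w1 + ((w1 < input_size - 1) ? 1 : 0);
    std::printf(
        "out[%lld] = %.2f * in[%lld] + %.2f * in[%lld]\n",
        static_cast<long long>(w),
        1.0f - lambda1,
        static_cast<long long>(w1),
        lambda1,
        static_cast<long long>(w2));
  }
  return 0;
}
```

Edge rows behave as expected: out[0] takes in[0] with weight 1, and out[7] clamps both taps to in[3].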
