
Commit 21b4bca

GregoryComer authored and facebook-github-bot committed
Add portable upsample_bilinear2d kernel (pytorch#6923)
Summary:
Add an upsample_bilinear2d kernel to the portable kernel library. This implementation reuses some of the inner logic from the ATen implementation (see Upsample.h and UpsampleKernel.cpp), but the outer kernel structure has not been ported, as it relies on TensorIterator and runtime allocation. It may be worth revisiting this in the future, either by pulling in more of the ATen implementation or by adding an optimized variant.

Test Plan:
Added comprehensive operator-level test coverage for upsample_bilinear2d.

```
buck test //executorch/kernels/test:portable_op_upsample_bilinear2d_test
buck test //executorch/kernels/test:aten_op_upsample_bilinear2d_test
```

Differential Revision: D65756150

Pulled By: GregoryComer
1 parent 1cf9482 commit 21b4bca

File tree

9 files changed: +939 −0 lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -403,6 +403,8 @@

 - op: unsqueeze_copy.out

+- op: upsample_bilinear2d.vec_out
+
 - op: upsample_nearest2d.out

 - op: upsample_nearest2d.vec_out
```
kernels/portable/cpu/op_upsample_bilinear2d.cpp

Lines changed: 135 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl(
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  auto in_plane = in_data;
  for (auto n = 0; n < out.size(0); n++) {
    for (auto c = 0; c < out.size(1); c++) {
      for (auto h = 0; h < out.size(2); h++) {
        for (auto w = 0; w < out.size(3); w++) {
          // Compute source index.
          // See area_pixel_compute_source_index in
          // pytorch/aten/src/ATen/native/UpSample.h
          int64_t in_h1, in_h2, in_w1, in_w2;
          float weight_h, inv_weight_h, weight_w, inv_weight_w;

          compute_source_index_and_lambda(
              in_h1,
              in_h2,
              weight_h,
              inv_weight_h,
              scale_h,
              h,
              in.sizes()[2],
              out.sizes()[2],
              align_corners);

          compute_source_index_and_lambda(
              in_w1,
              in_w2,
              weight_w,
              inv_weight_w,
              scale_w,
              w,
              in.sizes()[3],
              out.sizes()[3],
              align_corners);

          const auto top_left =
              in_plane[in_h1 * in.strides()[2] + in_w1 * in.strides()[3]];
          const auto top_right =
              in_plane[in_h1 * in.strides()[2] + in_w2 * in.strides()[3]];
          const auto bottom_left =
              in_plane[in_h2 * in.strides()[2] + in_w1 * in.strides()[3]];
          const auto bottom_right =
              in_plane[in_h2 * in.strides()[2] + in_w2 * in.strides()[3]];

          const auto top = top_left * weight_w + top_right * inv_weight_w;
          const auto bottom =
              bottom_left * weight_w + bottom_right * inv_weight_w;
          const auto val = top * weight_h + bottom * inv_weight_h;

          *out_data = val;
          out_data++;
        }
      }

      in_plane += in.strides()[1];
    }
  }
}
} // namespace

// Signatures are auto-generated, so disable pass-by-value lint.
// NOLINTBEGIN(facebook-hte-ConstantArgumentPassByValue, facebook-hte-ParameterMightThrowOnCopy)
Tensor& upsample_bilinear2d_vec_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t> output_size,
    bool align_corners,
    const exec_aten::OptionalArrayRef<double> scale_factors,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  //  In and out tensors have same dtype.
  //  In and out tensors are rank 4 and have same dim[0] and dim[1].
  //  In and out tensors are default dim order (NCHW).
  ET_KERNEL_CHECK(
      ctx,
      check_upsample_bilinear2d_args(
          in, output_size, align_corners, scale_factors, out),
      InvalidArgument,
      out);

  double scale_h, scale_w;

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_upsample_2d(
          in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor");

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, scale_w);

  ET_SWITCH_REAL_TYPES(
      in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
        upsample_bilinear2d_kernel_impl<CTYPE>(
            in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue, facebook-hte-ParameterMightThrowOnCopy)

} // namespace native
} // namespace executor
} // namespace torch
```
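The inner loop is a standard separable bilinear interpolation: two lerps along W (top and bottom rows), then one lerp along H between them. A minimal standalone sketch of that arithmetic, with made-up corner samples and weights (all values here are illustrative, not taken from the kernel):

```cpp
#include <cstdio>

// Mirrors the top/bottom/val computation in upsample_bilinear2d_kernel_impl
// for a single output pixel, using hypothetical corner values.
int main() {
  const float top_left = 1.0f, top_right = 2.0f;
  const float bottom_left = 3.0f, bottom_right = 4.0f;
  // weight_* weights the left/top sample; inv_weight_* weights the other.
  const float weight_w = 0.75f, inv_weight_w = 0.25f;
  const float weight_h = 0.75f, inv_weight_h = 0.25f;

  // Lerp along W on the top and bottom rows, then along H between them.
  const float top = top_left * weight_w + top_right * inv_weight_w; // 1.25
  const float bottom =
      bottom_left * weight_w + bottom_right * inv_weight_w; // 3.25
  const float val = top * weight_h + bottom * inv_weight_h; // 1.75
  std::printf("interpolated value: %f\n", val);
  return 0;
}
```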

kernels/portable/cpu/util/targets.bzl

Lines changed: 11 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@ def define_common_targets():
         "//executorch/kernels/portable/cpu/util:advanced_index_util",
         "//executorch/kernels/portable/cpu/util:slice_util",
         "//executorch/kernels/portable/cpu/util:elementwise_util",
+        "//executorch/kernels/portable/cpu/util:upsample_util",
     ],
     visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
 )
@@ -266,6 +267,16 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/..."],
     )

+    runtime.cxx_library(
+        name = "upsample_util",
+        srcs = ["upsample_util.cpp"],
+        exported_headers = ["upsample_util.h"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/..."],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in [True, False]:
         suffix = "_aten" if aten_mode else ""
```
kernels/portable/cpu/util/upsample_util.cpp

Lines changed: 94 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4);
  ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out));
  ET_LOG_AND_RETURN_IF_FALSE(
      output_size.has_value() ^ scale_factors.has_value());
  if (scale_factors.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value().size() == 2);
    ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[0] > 0);
    ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[1] > 0);
  } else if (output_size.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(output_size.value().size() == 2);
    ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[0] > 0);
    ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[1] > 0);
  }

  return true;
}

bool check_upsample_bilinear2d_args(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    ET_UNUSED const bool align_corners,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  return check_upsample_2d_common_args(in, output_size, scale_factors, out);
}

Error resize_upsample_2d(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    double& scale_h_out,
    double& scale_w_out,
    Tensor& out) {
  // Either output_size or scale_factors are provided, not both. This
  // is checked in check_..._args.
  // Scales are transformed according to align_corners.
  std::array<Tensor::SizesType, kTensorDimensionLimit> target_size;

  const auto dim = in.dim();
  std::copy(in.sizes().cbegin(), in.sizes().cend(), target_size.begin());

  if (scale_factors.has_value()) {
    scale_h_out = scale_factors.value()[0];
    scale_w_out = scale_factors.value()[1];

    target_size[dim - 2] =
        static_cast<Tensor::SizesType>(in.sizes()[dim - 2] * scale_h_out);
    target_size[dim - 1] =
        static_cast<Tensor::SizesType>(in.sizes()[dim - 1] * scale_w_out);
  } else if (output_size.has_value()) {
    scale_h_out =
        static_cast<double>(output_size.value()[0]) / in.sizes()[dim - 2];
    scale_w_out =
        static_cast<double>(output_size.value()[1]) / in.sizes()[dim - 1];

    target_size[dim - 2] = output_size.value()[0];
    target_size[dim - 1] = output_size.value()[1];
  } else {
    ET_LOG(Error, "Invalid output_size or scale_factors");
    return Error::InvalidArgument;
  }

  ET_CHECK_OR_RETURN_ERROR(
      target_size[dim - 2] > 0 && target_size[dim - 1] > 0,
      InvalidArgument,
      "Upsampled output size must be non-empty, but was %ld x %ld.",
      static_cast<long>(target_size[dim - 2]),
      static_cast<long>(target_size[dim - 1]));

  return resize_tensor(out, {target_size.data(), static_cast<size_t>(dim)});
}

} // namespace executor
} // namespace torch
```
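resize_upsample_2d derives the missing quantity in each branch: given scale factors, it truncates size × scale to an integer output size; given an output size, it derives the scales as out / in. A quick standalone check of both branches with hypothetical sizes (not part of the utility itself):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t in_h = 5, in_w = 7; // hypothetical input spatial dims

  // Branch 1: scale_factors given, e.g. {2.0, 1.5}. The static_cast in
  // resize_upsample_2d truncates, so 7 * 1.5 = 10.5 becomes 10.
  const double scale_h = 2.0, scale_w = 1.5;
  std::printf(
      "target size: %lld x %lld\n",
      static_cast<long long>(in_h * scale_h),  // 10
      static_cast<long long>(in_w * scale_w)); // 10

  // Branch 2: output_size given, e.g. {8, 8}; the scales handed to the
  // kernel are derived as out / in.
  const int64_t out_h = 8, out_w = 8;
  std::printf(
      "derived scales: %f, %f\n",
      static_cast<double>(out_h) / in_h,  // 1.6
      static_cast<double>(out_w) / in_w); // ~1.142857
  return 0;
}
```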
kernels/portable/cpu/util/upsample_util.h

Lines changed: 131 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out);

bool check_upsample_bilinear2d_args(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const bool align_corners,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out);

Error resize_upsample_2d(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    double& scale_h_out,
    double& scale_w_out,
    Tensor& out);

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t compute_scales_value(
    const exec_aten::optional<double>& scale,
    int64_t input_size,
    int64_t output_size) {
  return scale.has_value() ? static_cast<scalar_t>(1.0 / scale.value())
                           : (static_cast<scalar_t>(input_size) / output_size);
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const exec_aten::optional<double>& scale) {
  // see Note [area_pixel_compute_scale]
  if (align_corners) {
    if (output_size > 1) {
      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
    } else {
      return static_cast<scalar_t>(0);
    }
  } else {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
        static_cast<scalar_t>(0.5);
    return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
                                                          : src_idx;
  }
}

// Ported from aten/src/ATen/native/UpSample.h
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template <typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(
    const opmath_t& real_input_index,
    const int64_t& input_size,
    int64_t& input_index,
    scalar_t& lambda) {
  input_index =
      std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
  lambda = std::min(
      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
      static_cast<opmath_t>(1));
}

// Ported from aten/src/ATen/native/UpSample.h
template <typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
    int64_t& input_index0,
    int64_t& input_index1,
    scalar_t& lambda0,
    scalar_t& lambda1,
    opmath_t ratio,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    input_index0 = output_index;
    input_index1 = output_index;
    lambda0 = static_cast<scalar_t>(1);
    lambda1 = static_cast<scalar_t>(0);
  } else {
    const auto real_input_index = area_pixel_compute_source_index<opmath_t>(
        ratio, output_index, align_corners, /*cubic=*/false);
    guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
    int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
    input_index1 = input_index0 + offset;
    lambda0 = static_cast<scalar_t>(1.) - lambda1;
  }
}

} // namespace executor
} // namespace torch
```
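To make the align_corners semantics concrete, here is a standalone re-derivation of the scale and source indices for upsampling one dimension from 2 to 4. The formulas mirror area_pixel_compute_scale and area_pixel_compute_source_index for the non-cubic case with no explicit scale provided; this is an illustration, not a use of the header's API:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Scale, as in area_pixel_compute_scale when no explicit scale is given.
static double compute_scale(int64_t in, int64_t out, bool align_corners) {
  if (align_corners) {
    return out > 1 ? static_cast<double>(in - 1) / (out - 1) : 0.0;
  }
  return static_cast<double>(in) / out;
}

// Source index, as in area_pixel_compute_source_index (non-cubic branch).
static double source_index(double scale, int64_t dst, bool align_corners) {
  if (align_corners) {
    return scale * dst;
  }
  const double src = scale * (dst + 0.5) - 0.5;
  return src < 0.0 ? 0.0 : src; // negative indices clamp to 0
}

int main() {
  for (const bool ac : {true, false}) {
    const double scale = compute_scale(2, 4, ac);
    std::printf("align_corners=%d scale=%.4f:", ac, scale);
    for (int64_t dst = 0; dst < 4; ++dst) {
      std::printf(" %.4f", source_index(scale, dst, ac));
    }
    std::printf("\n");
  }
  // align_corners=1 scale=0.3333: 0.0000 0.3333 0.6667 1.0000
  // align_corners=0 scale=0.5000: 0.0000 0.2500 0.7500 1.2500
  return 0;
}
```

Note that the last index for align_corners=false (1.25) exceeds input_size - 1 = 1; this is exactly the case guard_index_and_lambda handles, clamping input_index to input_size - 1 and lambda into [0, 1].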
