Skip to content

Commit 058a5df

Browse files
committed
[ET-VK] Add support for aten::upsample_bilinear2d ATen op
Title says it all! Differential Revision: [D73261394](https://our.internmc.facebook.com/intern/diff/D73261394/) [ghstack-poisoned]
1 parent 5b7f235 commit 058a5df

File tree

6 files changed

+163
-92
lines changed

6 files changed

+163
-92
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ def register_view_op(features: OpFeatures):
540540
exir_ops.edge.aten.ones.default,
541541
exir_ops.edge.aten.ones_like.default,
542542
exir_ops.edge.aten.upsample_nearest2d.vec,
543+
exir_ops.edge.aten.upsample_bilinear2d.vec,
543544
exir_ops.edge.aten.zeros.default,
544545
exir_ops.edge.aten.zeros_like.default,
545546
exir_ops.edge.et_vk.grid_priors.default,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
15+
layout(std430) buffer;
16+
17+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
18+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
19+
${layout_declare_ubo(B, "ivec3", "out_limits")}
20+
${layout_declare_ubo(B, "ivec3", "in_limits")}
21+
${layout_declare_ubo(B, "vec2", "recip_scales")}
22+
23+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
24+
25+
void main() {
26+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
27+
28+
if (any(greaterThanEqual(pos, out_limits))) {
29+
return;
30+
}
31+
32+
ivec2 max_in_xy = in_limits.xy - 1;
33+
vec2 scaled_xy = pos.xy * recip_scales;
34+
35+
$if MODE == "nearest":
36+
const ivec2 ipos = clamp(ivec2(round(scaled_xy)), ivec2(0), max_in_xy);
37+
VEC4_T out_tex = texelFetch(t_in, ivec3(ipos, pos.z), 0);
38+
$elif MODE == "bilinear":
39+
vec2 upper_xy = ceil(scaled_xy);
40+
vec2 lower_xy = floor(scaled_xy);
41+
42+
// Clamp coordinates to valid input range
43+
upper_xy = clamp(upper_xy, ivec2(0), max_in_xy);
44+
lower_xy = clamp(lower_xy, ivec2(0), max_in_xy);
45+
46+
// Calculate interpolation weights
47+
vec2 interp_weights = (scaled_xy - lower_xy);
48+
49+
// Sample the four nearest texels
50+
VEC4_T sample00 = texelFetch(t_in, ivec3(lower_xy.x, lower_xy.y, pos.z), 0);
51+
VEC4_T sample10 = texelFetch(t_in, ivec3(upper_xy.x, lower_xy.y, pos.z), 0);
52+
VEC4_T sample01 = texelFetch(t_in, ivec3(lower_xy.x, upper_xy.y, pos.z), 0);
53+
VEC4_T sample11 = texelFetch(t_in, ivec3(upper_xy.x, upper_xy.y, pos.z), 0);
54+
55+
// Perform bilinear interpolation
56+
VEC4_T out_tex = mix(
57+
mix(sample00, sample10, interp_weights.x),
58+
mix(sample01, sample11, interp_weights.x),
59+
interp_weights.y
60+
);
61+
// VEC4_T out_tex = VEC4_T(interp_weights.y);
62+
63+
imageStore(t_out, pos, out_tex);
64+
}

backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/upsample_2d.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
upsample_nearest2d:
7+
upsample_2d:
88
parameter_names_with_default_values:
9-
NDIM: 3
109
DTYPE: float
11-
PACKING: C_packed
1210
STORAGE: texture3d
11+
MODE: nearest
1312
generate_variant_forall:
1413
DTYPE:
1514
- VALUE: half
1615
- VALUE: float
1716
shader_variants:
1817
- NAME: upsample_nearest2d
18+
- NAME: upsample_bilinear2d
19+
MODE: bilinear

backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl

Lines changed: 0 additions & 39 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/impl/Upsample.cpp

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
namespace vkcompute {
1818

19+
enum class UpsampleMode : int { NEAREST, BILINEAR };
20+
1921
void resize_upsample_nearest2d_node(
2022
ComputeGraph* graph,
2123
const std::vector<ArgGroup>& args,
@@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node(
3941
out->virtual_resize(out_sizes);
4042
}
4143

42-
// ExecuTorch-Vulkan framework to add node
43-
// Args:
44-
// in: will be converted from NCHW input tensor to 3D ARGB representation in
45-
// openGL (via ExecuTorch) output_sizes: optional 2D array of targetting
46-
// output size of H and W dimensions. >= input sizes;
47-
48-
// will be computed if only given the scale_factors.
49-
// scale_factors: optional 2D array of scale factors for H and W dimensions.
50-
// Will be computed if only given the output_sizes.
5144
void add_upsample_nearest2d_node(
5245
ComputeGraph& graph,
46+
const UpsampleMode mode,
5347
const ValueRef in,
5448
const ValueRef output_sizes,
49+
const ValueRef align_corners,
5550
const ValueRef scale_factors,
5651
const ValueRef out) {
5752
if (graph.val_is_none(output_sizes) && graph.val_is_none(scale_factors)) {
@@ -62,37 +57,51 @@ void add_upsample_nearest2d_node(
6257
VK_THROW(
6358
"Invalid input, must provide ONLY one of output_sizes or scale_factors");
6459
}
60+
utils::uvec3 in_limits = graph.logical_limits_of(in);
61+
utils::uvec3 out_limits = graph.logical_limits_of(out);
6562

66-
vTensorPtr t_in = graph.get_tensor(in);
67-
utils::uvec3 input_sizes = t_in->logical_limits();
63+
uint32_t out_width = out_limits[0u];
64+
uint32_t out_height = out_limits[1u];
6865

69-
utils::ivec2 input_size = {
70-
utils::safe_downcast<int32_t>(input_sizes[0]),
71-
utils::safe_downcast<int32_t>(input_sizes[1])};
72-
utils::vec2 rev_scales = {
73-
utils::safe_downcast<float>(1.0), utils::safe_downcast<float>(1.0)};
66+
float scale_factor_x = float(in_limits[0u]) / float(out_width);
67+
float scale_factor_y = float(in_limits[1u]) / float(out_height);
68+
69+
float recip_scale_factor_x = 1.0f / scale_factor_x;
70+
float recip_scale_factor_y = 1.0f / scale_factor_y;
7471

75-
// Reverse scale factors that pre-computed before GLSL.
7672
if (!graph.val_is_none(output_sizes)) {
77-
auto output_size_ref = graph.get_int_list(output_sizes);
78-
rev_scales = {
79-
utils::safe_downcast<float>(
80-
(float)input_size[0] / output_size_ref->at(1)),
81-
utils::safe_downcast<float>(
82-
(float)input_size[1] / output_size_ref->at(0))};
73+
IntListPtr output_size_ref = graph.get_int_list(output_sizes);
74+
out_width = output_size_ref->at(1);
75+
out_height = output_size_ref->at(0);
76+
77+
VK_CHECK_COND(out_width == out_limits[0u]);
78+
VK_CHECK_COND(out_height == out_limits[1u]);
8379

8480
} else {
85-
auto scales = graph.get_double_list(scale_factors);
86-
rev_scales = {
87-
utils::safe_downcast<float>(1.0 / scales->at(1)),
88-
utils::safe_downcast<float>(1.0 / scales->at(0))};
81+
DoubleListPtr scales = graph.get_double_list(scale_factors);
82+
scale_factor_x = scales->at(1);
83+
scale_factor_y = scales->at(0);
84+
85+
VK_CHECK_COND(in_limits[0u] * scale_factor_x == out_width);
86+
VK_CHECK_COND(in_limits[1u] * scale_factor_y == out_height);
8987
}
9088

91-
vTensorPtr t_out = graph.get_tensor(out);
89+
recip_scale_factor_x = float(in_limits[0u] - 1) / float(out_width - 1);
90+
recip_scale_factor_y = float(in_limits[1u] - 1) / float(out_height - 1);
91+
92+
utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y};
9293

93-
std::string kernel_name("upsample_nearest2d");
94+
std::string kernel_name;
9495
kernel_name.reserve(kShaderNameReserve);
95-
add_dtype_suffix(kernel_name, *t_out);
96+
switch (mode) {
97+
case UpsampleMode::NEAREST:
98+
kernel_name = "upsample_nearest2d";
99+
break;
100+
case UpsampleMode::BILINEAR:
101+
kernel_name = "upsample_bilinear2d";
102+
break;
103+
}
104+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
96105

97106
graph.execute_nodes().emplace_back(new DispatchNode(
98107
graph,
@@ -103,21 +112,44 @@ void add_upsample_nearest2d_node(
103112
{{out, vkapi::MemoryAccessType::WRITE},
104113
{in, vkapi::MemoryAccessType::READ}},
105114
// Shader params buffers
106-
{t_out->logical_limits_ubo(),
107-
graph.create_params_buffer(input_size),
108-
graph.create_params_buffer(rev_scales)},
115+
{graph.logical_limits_ubo(out),
116+
graph.logical_limits_ubo(in),
117+
graph.create_params_buffer(recip_scales)},
109118
// Specialization Constants
110119
{},
111120
resize_upsample_nearest2d_node,
112121
{output_sizes, scale_factors}));
113122
}
114123

115-
void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {
116-
return add_upsample_nearest2d_node(graph, args[0], args[1], args[2], args[3]);
124+
void upsample_nearest2d(
125+
ComputeGraph& graph,
126+
const std::vector<ValueRef>& args) {
127+
return add_upsample_nearest2d_node(
128+
graph,
129+
UpsampleMode::NEAREST,
130+
args[0],
131+
args[1],
132+
kDummyValueRef,
133+
args[2],
134+
args[3]);
135+
}
136+
137+
void upsample_bilinear2d(
138+
ComputeGraph& graph,
139+
const std::vector<ValueRef>& args) {
140+
return add_upsample_nearest2d_node(
141+
graph,
142+
UpsampleMode::BILINEAR,
143+
args[0],
144+
args[1],
145+
args[2],
146+
args[3],
147+
args[4]);
117148
}
118149

119150
REGISTER_OPERATORS {
120-
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample);
151+
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample_nearest2d);
152+
VK_REGISTER_OP(aten.upsample_bilinear2d.vec, upsample_bilinear2d);
121153
}
122154

123155
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -430,21 +430,33 @@ def get_native_layer_norm_inputs():
430430
return test_suite
431431

432432

433-
@register_test_suite("aten.upsample_nearest2d.vec")
434433
def get_upsample_inputs():
435-
test_suite = VkTestSuite(
436-
[
437-
# (input tensor shape, output 2D image size (H, W), output scaling factors)
438-
((2, 2, 2, 2), None, [1, 1]),
439-
((1, 1, 2, 2), None, [2, 2]),
440-
((1, 1, 2, 2), None, [2, 4]),
441-
((1, 1, 2, 2), None, [4, 2]),
442-
((1, 1, 2, 2), [2, 2], None),
443-
((1, 1, 2, 2), [2, 4], None),
444-
((1, 1, 2, 2), [3, 2], None),
445-
]
446-
)
447-
return test_suite
434+
inputs_list = [
435+
# (input tensor shape, output 2D image size (H, W), output scaling factors)
436+
((2, 2, 2, 2), None, [1, 1]),
437+
((1, 1, 2, 2), None, [2, 2]),
438+
((1, 1, 2, 2), None, [2, 4]),
439+
((1, 1, 2, 2), None, [4, 2]),
440+
((1, 1, 2, 2), [2, 2], None),
441+
((1, 1, 2, 2), [2, 4], None),
442+
((1, 1, 2, 2), [3, 2], None),
443+
]
444+
return inputs_list
445+
446+
447+
@register_test_suite("aten.upsample_nearest2d.vec")
448+
def get_upsample_nearest2d_inputs():
449+
inputs_list = get_upsample_inputs()
450+
return VkTestSuite(inputs_list)
451+
452+
453+
@register_test_suite("aten.upsample_bilinear2d.vec")
454+
def get_upsample_bilinear2d_inputs():
455+
base_inputs_list = get_upsample_inputs()
456+
inputs_list = []
457+
for input_case in base_inputs_list:
458+
inputs_list.append((input_case[0], input_case[1], True, input_case[2]))
459+
return VkTestSuite(inputs_list)
448460

449461

450462
@register_test_suite(["aten.full.default", "aten.full_like.default"])

0 commit comments

Comments
 (0)