Skip to content

Commit 4fb9443

Browse files
committed
Add support for upsample_nearest2d op in the Arm backend
Change-Id: Id0b742214e5432957b2f573b4218f09a4d9734e4
1 parent 03b1ef2 commit 4fb9443

File tree

9 files changed

+444
-47
lines changed

9 files changed

+444
-47
lines changed

backends/arm/arm_partitioner.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import operator
1010
import os
11-
from typing import cast, final, List
11+
from typing import Callable, cast, final, List, Optional, Tuple
1212

1313
import torch
1414
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
@@ -68,6 +68,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6868
exir_ops.edge.aten.sub.Tensor,
6969
exir_ops.edge.aten.sum.dim_IntList,
7070
exir_ops.edge.aten.tanh.default,
71+
exir_ops.edge.aten.upsample_nearest2d.vec,
7172
exir_ops.edge.aten.view_copy.default,
7273
exir_ops.edge.aten.clone.default,
7374
exir_ops.edge.aten.mean.dim,
@@ -136,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
136137
return PartitionResult(
137138
tagged_exported_program=exported_program, partition_tags=partition_tags
138139
)
140+
141+
def ops_to_not_decompose(
142+
self,
143+
ep: ExportedProgram,
144+
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
145+
ops_to_not_decompose = [
146+
torch.ops.aten.upsample_nearest2d.vec,
147+
]
148+
return (ops_to_not_decompose, None)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@
3737
op_tanh,
3838
op_transpose,
3939
op_unsqueeze,
40+
op_upsample_nearest2d,
4041
op_view,
4142
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import serializer.tosa_serializer as ts
8+
import torch
9+
from executorch.backends.arm.operators.node_visitor import (
10+
NodeVisitor,
11+
register_node_visitor,
12+
)
13+
from executorch.backends.arm.tosa_mapping import TosaArg
14+
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
15+
from serializer.tosa_serializer import TosaOp
16+
17+
from tosa.ResizeMode import ResizeMode
18+
19+
20+
@register_node_visitor
21+
class UpsampleNearest2dVisitor(NodeVisitor):
22+
target = "aten.upsample_nearest2d.vec"
23+
24+
def __init__(self, *args):
25+
super().__init__(*args)
26+
27+
def define_node(
28+
self,
29+
node: torch.fx.Node,
30+
tosa_graph: ts.TosaSerializer,
31+
inputs: List[TosaArg],
32+
output: TosaArg,
33+
is_quant_node: bool,
34+
) -> None:
35+
assert (
36+
inputs[0].shape is not None and output.shape is not None
37+
), "Only static shapes are supported"
38+
39+
# tosa_shape output is NHWC, take HW
40+
input_size_yx = torch.tensor(
41+
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
42+
)
43+
# Ignore scale and size parameters, directly use the output size as
44+
# we only support static shapes currently
45+
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
46+
47+
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
48+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
49+
)
50+
51+
def in_int16_range(x):
52+
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
53+
54+
assert in_int16_range(scale_n_yx)
55+
assert in_int16_range(scale_d_yx)
56+
assert in_int16_range(border_yx)
57+
58+
attr = ts.TosaSerializerAttribute()
59+
attr.ResizeAttribute(
60+
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
61+
offset=offset_yx.tolist(),
62+
border=border_yx.tolist(),
63+
mode=ResizeMode.NEAREST,
64+
)
65+
66+
tosa_graph.addOperator(
67+
TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
68+
)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ class ArmQuantizer(Quantizer):
271271
"one_to_one",
272272
"generic",
273273
"sum",
274+
"upsample_nearest2d",
274275
]
275276

276277
def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,5 @@ def decorator(annotator: AnnotatorType):
6060
one_to_one_annotator,
6161
sub_annotator,
6262
sum_annotator,
63+
upsample_nearest2d_annotator,
6364
)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import itertools
7+
from typing import Callable, List, Optional
8+
9+
import torch
10+
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
11+
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
12+
from torch.ao.quantization.quantizer import (
13+
QuantizationAnnotation,
14+
SharedQuantizationSpec,
15+
)
16+
from torch.fx import Node
17+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
18+
19+
20+
def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None):
21+
def filter(node: Node):
22+
is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec
23+
if filter_fn is None:
24+
return is_upsample
25+
else:
26+
return is_upsample and filter_fn(node)
27+
28+
return filter
29+
30+
31+
@register_annotator("upsample_nearest2d")
32+
def _annotate_upsample_nearest2d(
33+
gm: torch.fx.GraphModule,
34+
quantization_config: QuantizationConfig,
35+
filter_fn: Optional[Callable[[Node], bool]] = None,
36+
) -> Optional[List[List[Node]]]:
37+
module_partitions = get_source_partitions(
38+
gm.graph,
39+
[
40+
torch.nn.UpsamplingNearest2d,
41+
torch.nn.Upsample,
42+
torch.nn.functional.interpolate,
43+
],
44+
_filter_upsample_nearest2d(filter_fn),
45+
)
46+
upsample_partitions = list(
47+
itertools.chain.from_iterable(module_partitions.values())
48+
)
49+
annotated_partitions = []
50+
51+
for upsample_partition in upsample_partitions:
52+
annotated_partitions.append(upsample_partition.nodes)
53+
54+
assert len(upsample_partition.nodes) == 1
55+
upsample_node = upsample_partition.nodes[0]
56+
57+
input_act = upsample_node.args[0]
58+
assert isinstance(input_act, Node)
59+
60+
input_act_qspec = quantization_config.get_input_act_qspec()
61+
output_act_qspec = SharedQuantizationSpec((input_act, upsample_node))
62+
63+
upsample_node.meta["quantization_annotation"] = QuantizationAnnotation(
64+
input_qspec_map={
65+
input_act: input_act_qspec,
66+
},
67+
output_qspec=output_act_qspec,
68+
_annotated=True,
69+
)
70+
71+
return annotated_partitions
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from typing import Optional, Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14+
from parameterized import parameterized
15+
16+
17+
test_data_suite = [
18+
# (test_name, test_data, size, scale_factor, compare_outputs)
19+
("rand_double_scale", torch.rand(2, 4, 8, 3), None, 2.0, True),
20+
("rand_double_scale_one_dim", torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True),
21+
("rand_double_size", torch.rand(2, 4, 8, 3), (16, 6), None, True),
22+
("rand_one_double_scale", torch.rand(2, 4, 1, 1), None, 2.0, True),
23+
("rand_one_double_size", torch.rand(2, 4, 1, 1), (2, 2), None, True),
24+
("rand_one_same_scale", torch.rand(2, 4, 1, 1), None, 1.0, True),
25+
("rand_one_same_size", torch.rand(2, 4, 1, 1), (1, 1), None, True),
26+
# Can't compare outputs as the rounding when selecting the nearest pixel is
27+
# different between PyTorch and TOSA. Just check the legalization went well.
28+
# TODO Improve the test infrastructure to support more in depth verification
29+
# of the TOSA legalization results.
30+
("rand_half_scale", torch.rand(2, 4, 8, 6), None, 0.5, False),
31+
("rand_half_size", torch.rand(2, 4, 8, 6), (4, 3), None, False),
32+
("rand_one_and_half_scale", torch.rand(2, 4, 8, 3), None, 1.5, False),
33+
("rand_one_and_half_size", torch.rand(2, 4, 8, 3), (12, 4), None, False),
34+
]
35+
36+
37+
class TestUpsampleNearest2d(unittest.TestCase):
38+
class UpsamplingNearest2d(torch.nn.Module):
39+
def __init__(
40+
self,
41+
size: Optional[Tuple[int]],
42+
scale_factor: Optional[float | Tuple[float]],
43+
):
44+
super().__init__()
45+
self.upsample = torch.nn.UpsamplingNearest2d( # noqa: TOR101
46+
size=size, scale_factor=scale_factor
47+
)
48+
49+
def forward(self, x):
50+
return self.upsample(x)
51+
52+
class Upsample(torch.nn.Module):
53+
def __init__(
54+
self,
55+
size: Optional[Tuple[int]],
56+
scale_factor: Optional[float | Tuple[float]],
57+
):
58+
super().__init__()
59+
self.upsample = torch.nn.Upsample(
60+
size=size, scale_factor=scale_factor, mode="nearest"
61+
)
62+
63+
def forward(self, x):
64+
return self.upsample(x)
65+
66+
class Interpolate(torch.nn.Module):
67+
def __init__(
68+
self,
69+
size: Optional[Tuple[int]],
70+
scale_factor: Optional[float | Tuple[float]],
71+
):
72+
super().__init__()
73+
self.upsample = lambda x: torch.nn.functional.interpolate(
74+
x, size=size, scale_factor=scale_factor, mode="nearest"
75+
)
76+
77+
def forward(self, x):
78+
return self.upsample(x)
79+
80+
def _test_upsample_nearest_2d_tosa_MI_pipeline(
81+
self,
82+
module: torch.nn.Module,
83+
test_data: Tuple[torch.tensor],
84+
compare_outputs: bool,
85+
):
86+
tester = (
87+
ArmTester(
88+
module,
89+
example_inputs=test_data,
90+
compile_spec=common.get_tosa_compile_spec(),
91+
)
92+
.export()
93+
.check(["torch.ops.aten.upsample_nearest2d.vec"])
94+
.check_not(["torch.ops.quantized_decomposed"])
95+
.to_edge_transform_and_lower()
96+
.check_not(["torch.ops.aten.upsample_nearest2d.vec"])
97+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
98+
.to_executorch()
99+
)
100+
101+
if compare_outputs:
102+
tester.run_method_and_compare_outputs(inputs=test_data)
103+
104+
def _test_upsample_nearest_2d_tosa_BI_pipeline(
105+
self,
106+
module: torch.nn.Module,
107+
test_data: Tuple[torch.tensor],
108+
compare_outputs: bool,
109+
):
110+
tester = (
111+
ArmTester(
112+
module,
113+
example_inputs=test_data,
114+
compile_spec=common.get_tosa_compile_spec(),
115+
)
116+
.quantize()
117+
.export()
118+
.check(["torch.ops.aten.upsample_nearest2d.vec"])
119+
.check(["torch.ops.quantized_decomposed"])
120+
.to_edge_transform_and_lower()
121+
.check_not(["torch.ops.aten.upsample_nearest2d.vec"])
122+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
123+
.to_executorch()
124+
)
125+
126+
if compare_outputs:
127+
tester.run_method_and_compare_outputs(inputs=test_data)
128+
129+
@parameterized.expand(test_data_suite)
130+
def test_upsample_nearest_2d_tosa_MI(
131+
self,
132+
test_name: str,
133+
test_data: torch.Tensor,
134+
size: Optional[Tuple[int]],
135+
scale_factor: Optional[float | Tuple[float]],
136+
compare_outputs: bool,
137+
):
138+
self._test_upsample_nearest_2d_tosa_MI_pipeline(
139+
self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
140+
)
141+
self._test_upsample_nearest_2d_tosa_MI_pipeline(
142+
self.Upsample(size, scale_factor), (test_data,), compare_outputs
143+
)
144+
self._test_upsample_nearest_2d_tosa_MI_pipeline(
145+
self.Interpolate(size, scale_factor), (test_data,), compare_outputs
146+
)
147+
148+
@parameterized.expand(test_data_suite)
149+
def test_upsample_nearest_2d_tosa_BI(
150+
self,
151+
test_name: str,
152+
test_data: torch.Tensor,
153+
size: Optional[Tuple[int]],
154+
scale_factor: Optional[float | Tuple[float]],
155+
compare_outputs: bool,
156+
):
157+
self._test_upsample_nearest_2d_tosa_BI_pipeline(
158+
self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
159+
)
160+
self._test_upsample_nearest_2d_tosa_BI_pipeline(
161+
self.Upsample(size, scale_factor), (test_data,), compare_outputs
162+
)
163+
self._test_upsample_nearest_2d_tosa_BI_pipeline(
164+
self.Interpolate(size, scale_factor), (test_data,), compare_outputs
165+
)

0 commit comments

Comments
 (0)