Skip to content

Commit ac2c073

Browse files
Arm backend: Add TOSA dialect op for RESIZE (#14513)
Add TOSA backend dialect op for TOSA RESIZE. The dialect op replaces upsample_nearest2d and upsample_bilinear_2d in RewriteUpsamplePass. Also the Nodevisitors of upsample_nearest2d and upsample_bilinear2d are replaced by one NodeVisitor for the resize backend dialect op. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 9602b2e commit ac2c073

File tree

9 files changed

+165
-158
lines changed

9 files changed

+165
-158
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
ReplaceScalarWithTensorArgPassTOSABI,
9292
ReplaceScalarWithTensorArgPassTOSAMI,
9393
)
94+
from .rewrite_upsample import RewriteUpsamplePass # noqa
9495
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
9596
from .size_adjust_input_pass import SizeAdjustInputPass # noqa
9697
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
ReplaceScalarWithTensorArgPassTOSABI,
9292
ReplaceScalarWithTensorArgPassTOSAMI,
9393
RetraceFoldedDtypesPass,
94+
RewriteUpsamplePass,
9495
ScalarsToAttributePass,
9596
SizeAdjustInputPass,
9697
ToTosaMemoryFormatPass,
@@ -206,6 +207,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
206207
# needs to happen before AddBiasPass, but after the table ops are inserted
207208
# to be able to validate that conv2d has right dtype arguments.
208209
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
210+
self.add_pass(RewriteUpsamplePass(exported_program))
209211
self.add_pass(AddBiasPass(exported_program))
210212

211213
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
@@ -290,6 +292,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
290292
self.add_pass(FuseViewCopyTransform())
291293
self.add_pass(FuseConstantArgsPass(exported_program))
292294
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
295+
self.add_pass(RewriteUpsamplePass(exported_program))
293296
self.add_pass(AddBiasPass(exported_program))
294297
self.add_pass(InsertTableOpsPass(exported_program))
295298
self.add_pass(FuseEqualPlaceholdersPass(exported_program))

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ def call(self, graph_module):
114114
if node.op != "call_function":
115115
continue
116116
if node.target in [
117-
exir_ops.backend.tosa.TABLE.default,
118117
exir_ops.backend.tosa.RESCALE.default,
118+
exir_ops.backend.tosa.RESIZE.default,
119+
exir_ops.backend.tosa.TABLE.default,
119120
exir_ops.backend.tosa.TRANSPOSE.default,
120121
]:
121122
continue
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.arm_pass_utils import (
11+
create_node,
12+
get_first_fake_tensor,
13+
)
14+
from executorch.backends.arm.tosa.utils import get_resize_parameters
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
from executorch.exir.pass_base import ExportPass, PassResult
17+
18+
19+
class RewriteUpsamplePass(ArmPass):
20+
"""Rewrite upsample2d nodes to TOSA.RESIZE nodes."""
21+
22+
targeted_ops = (
23+
exir_ops.edge.aten.upsample_nearest2d.vec,
24+
exir_ops.edge.aten.upsample_bilinear2d.vec,
25+
)
26+
27+
_passes_required_after: Set[Type[ExportPass]] = set()
28+
29+
def call(self, graph_module):
30+
modified = False
31+
for node in graph_module.graph.nodes:
32+
if node.op != "call_function" or node.target not in self.targeted_ops:
33+
continue
34+
modified = True
35+
36+
if node.target == exir_ops.edge.aten.upsample_bilinear2d.vec:
37+
x, output_size, align_corners, scale_factors = node.args
38+
resize_mode = "bilinear"
39+
else:
40+
x, output_size, scale_factors = node.args
41+
align_corners = False
42+
resize_mode = "nearest"
43+
44+
with graph_module.graph.inserting_before(node):
45+
tosa_resize_node = create_node(
46+
graph_module.graph,
47+
op_target=exir_ops.backend.tosa.RESIZE.default,
48+
args=(x, output_size, align_corners, scale_factors),
49+
kwargs={"resize_mode": resize_mode},
50+
from_node=node,
51+
)
52+
node.replace_all_uses_with(tosa_resize_node)
53+
graph_module.graph.erase_node(node)
54+
input_dtype = get_first_fake_tensor(x).dtype
55+
if input_dtype == torch.int8 and resize_mode == "bilinear":
56+
input_size = get_first_fake_tensor(x).shape
57+
input_size_xy = input_size[2:]
58+
output_size = get_first_fake_tensor(node).shape
59+
output_size_xy = output_size[2:]
60+
scale_n_yx, _, _, _ = get_resize_parameters(
61+
input_size_xy=input_size_xy,
62+
output_size_xy=output_size_xy,
63+
resize_mode=1,
64+
align_corners=align_corners,
65+
)
66+
output_dtype = get_first_fake_tensor(node).dtype
67+
output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
68+
with graph_module.graph.inserting_after(tosa_resize_node):
69+
rescale_node = create_node(
70+
graph_module.graph,
71+
exir_ops.backend.tosa.RESCALE.default,
72+
)
73+
tosa_resize_node.replace_all_uses_with(rescale_node)
74+
rescale_node.args = (
75+
tosa_resize_node,
76+
output_dtype,
77+
output_scale,
78+
0, # zero point
79+
0, # zero point
80+
)
81+
82+
if modified:
83+
graph_module = super().call(graph_module).graph_module
84+
return PassResult(graph_module, modified)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
op_reciprocal,
4444
op_repeat,
4545
op_rescale,
46+
op_resize,
4647
op_rshift_tensor,
4748
op_rsqrt,
4849
op_sigmoid,
@@ -54,8 +55,6 @@
5455
op_tanh,
5556
op_to_dim_order_copy,
5657
op_transpose,
57-
op_upsample_bilinear2d,
58-
op_upsample_nearest2d,
5958
op_view,
6059
op_where,
6160
ops_binary,

backends/arm/operators/op_upsample_nearest2d.py renamed to backends/arm/operators/op_resize.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424

2525

2626
@register_node_visitor
27-
class UpsampleNearest2dVisitor(NodeVisitor):
28-
target = "aten.upsample_nearest2d.vec"
27+
class ResizeVisitor(NodeVisitor):
28+
target = "tosa.RESIZE.default"
2929

3030
tosa_specs = NodeVisitor.tosa_specs
3131

@@ -41,12 +41,18 @@ def define_node(
4141
) -> None:
4242
import serializer.tosa_serializer as ts
4343

44-
validate_num_inputs(self.target, inputs, 3)
45-
validate_same_dtype(self.target, [inputs[0], output], ts)
44+
validate_num_inputs(self.target, inputs, [3, 4])
45+
if node.kwargs.get("resize_mode") == "bilinear":
46+
resize_mode = ResizeMode.BILINEAR
47+
align_corners = bool(node.args[2])
48+
else:
49+
resize_mode = ResizeMode.NEAREST
50+
align_corners = False
51+
validate_same_dtype(self.target, [inputs[0], output], ts)
4652
validate_valid_dtype(
4753
self.target,
4854
[inputs[0], output],
49-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
55+
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP16, ts.DType.FP32],
5056
output.tosa_spec,
5157
)
5258

@@ -59,7 +65,7 @@ def define_node(
5965
# Align corners shouldn't make a difference for nearest upsampling. We set to False so
6066
# half pixel centers are used for resize parameter logic.
6167
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
62-
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=False
68+
input_size_yx, output_size_yx, resize_mode, align_corners=align_corners
6369
)
6470

6571
def in_int16_range(x):
@@ -86,7 +92,7 @@ def in_int16_range(x):
8692
)
8793
attr = ts.TosaSerializerAttribute()
8894
attr.ResizeAttribute(
89-
mode=ResizeMode.NEAREST,
95+
mode=resize_mode,
9096
)
9197

9298
self._serialize_operator(

backends/arm/operators/op_upsample_bilinear2d.py

Lines changed: 0 additions & 148 deletions
This file was deleted.

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
77
rescale,
8+
resize,
89
table,
910
transpose,
1011
)

0 commit comments

Comments
 (0)