Skip to content

Commit 9ac84b2

Browse files
perSaoirseARM
authored andcommitted
Arm backend: Support for rescale for TOSA 1.0
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I88822c93470c4ff54f0e75df2f5ff271fa6bd1e6
1 parent 9e0292f commit 9ac84b2

File tree

1 file changed

+78
-7
lines changed

1 file changed

+78
-7
lines changed

backends/arm/operators/op_rescale.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,35 @@
55

66
# pyre-unsafe
77

8-
from typing import cast, List
8+
from typing import Any, cast, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils
1111
import torch
12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
13-
14-
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
1512
from executorch.backends.arm.operators.node_visitor import (
1613
NodeVisitor,
1714
register_node_visitor,
1815
)
1916
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
17+
from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale
18+
19+
from executorch.backends.arm.tosa_specification import TosaSpecification
2020
from torch.fx import Node
2121

2222

2323
@register_node_visitor
24-
class RescaleVisitor(NodeVisitor):
24+
class RescaleVisitor_0_80(NodeVisitor):
2525
target = "_rescale.default"
2626

27+
tosa_specs = NodeVisitor.tosa_specs_0_80
28+
2729
def define_node(
2830
self,
2931
node: Node,
30-
tosa_graph: ts.TosaSerializer,
32+
tosa_graph: Any,
3133
inputs: List[TosaArg],
3234
output: TosaArg,
3335
) -> None:
36+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3437

3538
input_dtype = inputs[0].dtype
3639
output_dtype = cast(torch.dtype, node.args[1])
@@ -68,5 +71,73 @@ def define_node(
6871
)
6972

7073
tosa_graph.addOperator(
71-
TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
74+
ts.TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
75+
)
76+
77+
78+
@register_node_visitor
79+
class RescaleVisitor_INT(NodeVisitor):
80+
target = "_rescale.default"
81+
82+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")]
83+
84+
def define_node(
85+
self,
86+
node: Node,
87+
tosa_graph: Any,
88+
inputs: List[TosaArg],
89+
output: TosaArg,
90+
) -> None:
91+
import serializer.tosa_serializer as ts # type: ignore
92+
from tosa.RoundingMode import RoundingMode # type: ignore
93+
94+
input_dtype = inputs[0].dtype
95+
output_dtype = cast(torch.dtype, node.args[1])
96+
scale = cast(float, node.args[2])
97+
input_zp = cast(int, node.args[3])
98+
output_zp = cast(int, node.args[4])
99+
100+
if input_dtype != map_dtype(torch.int8) and input_zp != 0:
101+
raise ValueError(
102+
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
103+
)
104+
if output_dtype != torch.int8 and output_zp != 0:
105+
raise ValueError(
106+
f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
107+
)
108+
109+
# scale32 gives higher accuracy but for a higher HW cost.
110+
# For now, always go for scale32.
111+
scale_32 = True
112+
scale_width = 32 if scale_32 else 16
113+
multipliers, shifts = tosa_quant_utils.compute_multiplier_and_shift(
114+
[scale], scale_width
115+
)
116+
117+
rescale_inputs = create_const_ops_for_rescale(
118+
tosa_graph,
119+
input_dtype,
120+
inputs[0].name,
121+
multipliers,
122+
shifts,
123+
input_zp,
124+
output_zp,
125+
ts,
126+
)
127+
128+
attr_rescale = ts.TosaSerializerAttribute()
129+
130+
attr_rescale.RescaleAttribute(
131+
scale32=scale_32,
132+
rounding_mode=RoundingMode.SINGLE_ROUND,
133+
per_channel=False,
134+
input_unsigned=False,
135+
output_unsigned=False,
136+
)
137+
138+
tosa_graph.addOperator(
139+
ts.TosaOp.Op().RESCALE,
140+
[inputs[0].name, *rescale_inputs],
141+
[output.name],
142+
attr_rescale,
72143
)

0 commit comments

Comments
 (0)