Skip to content

Commit a765d7e

Browse files
committed
Arm backend: Update nodevisitors affected by RESCALE updates
- Updated abs, eq, ge, gt, le, lt, maximum (Nan attribute added), minimum (Nan attribute added) Signed-off-by: Saoirse Stewart <[email protected]> Change-Id: Ib85d698c3b22ec9d8506d0ccb08e23966ec9e018
1 parent 444c0aa commit a765d7e

File tree

8 files changed

+599
-34
lines changed

8 files changed

+599
-34
lines changed

backends/arm/operators/op_abs.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import executorch.backends.arm.tosa_quant_utils as tqutils
1010
import executorch.backends.arm.tosa_utils as tutils
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
@@ -33,10 +32,13 @@ def __init__(self, *args):
3332
def define_node(
3433
self,
3534
node: Node,
36-
tosa_graph: ts.TosaSerializer,
35+
tosa_graph: Any,
3736
inputs: List[TosaArg],
3837
output: TosaArg,
3938
) -> None:
39+
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
41+
4042
# Specification (0.80) states that input and output types
4143
# should all be the same
4244
if not (inputs[0].dtype == output.dtype):
@@ -53,7 +55,7 @@ def define_node(
5355
if inputs[0].dtype == ts.DType.INT8:
5456
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5557
tosa_graph, inputs, node
56-
)
58+
) # type: ignore[possibly-undefined]
5759
else:
5860
# input[0].dtype == ts.DType.INT32
5961
# Non quantized input, natively support by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
9698
def define_node(
9799
self,
98100
node: Node,
99-
tosa_graph: ts.TosaSerializer,
101+
tosa_graph: Any,
100102
inputs: List[TosaArg],
101103
output: TosaArg,
102104
) -> None:
105+
106+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107+
103108
# Specification (0.80) states that input and output types
104109
# should all be the same
105110
if not (inputs[0].dtype == output.dtype):
@@ -129,3 +134,122 @@ def define_node(
129134
[output.name],
130135
None,
131136
)
137+
138+
139+
@register_node_visitor
140+
class AbsVisitor_INT(NodeVisitor):
141+
target = "aten.abs.default"
142+
143+
tosa_specs = [
144+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145+
]
146+
147+
def __init__(self, *args):
148+
super().__init__(*args)
149+
150+
def define_node(
151+
self,
152+
node: Node,
153+
tosa_graph: Any,
154+
inputs: List[TosaArg],
155+
output: TosaArg,
156+
) -> None:
157+
158+
import serializer.tosa_serializer as ts # type: ignore
159+
160+
# Specification (1.0) states that input and output types
161+
# should all be the same
162+
if not (inputs[0].dtype == output.dtype):
163+
raise ValueError(
164+
"All inputs and outputs need same dtype."
165+
f"Got {inputs[0].dtype=}, {output.dtype=}"
166+
)
167+
# Handle int8 (quantized) and int32
168+
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
169+
raise ValueError(
170+
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
171+
)
172+
173+
scale_back = 1.0
174+
if inputs[0].dtype == ts.DType.INT8:
175+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
176+
tosa_graph, inputs, node, self.tosa_specs
177+
) # type: ignore[possibly-undefined]
178+
else:
179+
# input[0].dtype == ts.DType.INT32
180+
# Non quantized input, natively support by TOSA.abs
181+
rescaled_inputs = inputs
182+
183+
if output.dtype == ts.DType.INT8:
184+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
185+
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
186+
else:
187+
# output.dtype == ts.DType.INT32
188+
abs_output = output
189+
190+
# Do the INT32 Abs
191+
tosa_graph.addOperator(
192+
ts.TosaOp.Op().ABS,
193+
[
194+
rescaled_inputs[0].name,
195+
],
196+
[abs_output.name],
197+
None,
198+
)
199+
200+
if output.dtype == ts.DType.INT8:
201+
# Scale output back to 8 bit
202+
# pyre-ignore
203+
tqutils.insert_rescale_op_to_int8(
204+
tosa_graph, abs_output, scale_back, node, self.tosa_specs
205+
) # type: ignore[possibly-undefined]
206+
207+
208+
@register_node_visitor
209+
class AbsVisitor_FP(AbsVisitor_INT):
210+
# inheriting 'target' from BI class
211+
212+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
213+
214+
def __init__(self, *args):
215+
super().__init__(*args)
216+
217+
def define_node(
218+
self,
219+
node: Node,
220+
tosa_graph: Any,
221+
inputs: List[TosaArg],
222+
output: TosaArg,
223+
) -> None:
224+
225+
import serializer.tosa_serializer as ts # type: ignore
226+
227+
# Specification (1.0) states that input and output types
228+
# should all be the same
229+
if not (inputs[0].dtype == output.dtype):
230+
raise ValueError(
231+
"All inputs and output need same dtype."
232+
f"Got {inputs[0].dtype=}, {output.dtype=}"
233+
)
234+
235+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
236+
# Call the inherited define_node for handling integers
237+
super().define_node(node, tosa_graph, inputs, output)
238+
else:
239+
# FP32 Abs lowering
240+
241+
if not (inputs[0].dtype == ts.DType.FP32):
242+
raise ValueError(
243+
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
244+
)
245+
246+
if not (output.dtype == ts.DType.FP32):
247+
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
248+
249+
# MI lowering
250+
tosa_graph.addOperator(
251+
ts.TosaOp.Op().ABS,
252+
[inputs[0].name],
253+
[output.name],
254+
None,
255+
)

backends/arm/operators/op_eq.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,42 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
1615
)
1716
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
1818

1919
from torch.fx import Node
2020

2121

2222
@register_node_visitor
23-
class EqualVisitor(NodeVisitor):
23+
class EqualVisitor_0_80(NodeVisitor):
2424
target = "aten.eq.Tensor"
2525

26+
tosa_specs = [
27+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
28+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
29+
]
30+
2631
def __init__(self, *args):
2732
super().__init__(*args)
2833

2934
def define_node(
3035
self,
3136
node: Node,
32-
tosa_graph: ts.TosaSerializer,
37+
tosa_graph: Any,
3338
inputs: List[TosaArg],
3439
output: TosaArg,
3540
) -> None:
41+
42+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
43+
3644
if inputs[0].dtype != inputs[1].dtype:
3745
raise TypeError(
3846
"All inputs need to have the same data type for operator EQ but got "
@@ -57,3 +65,51 @@ def define_node(
5765
output.name,
5866
None,
5967
)
68+
69+
70+
@register_node_visitor
71+
class EqualVisitor(NodeVisitor):
72+
target = "aten.eq.Tensor"
73+
74+
tosa_specs = [
75+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
76+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
77+
]
78+
79+
def __init__(self, *args):
80+
super().__init__(*args)
81+
82+
def define_node(
83+
self,
84+
node: Node,
85+
tosa_graph: Any,
86+
inputs: List[TosaArg],
87+
output: TosaArg,
88+
) -> None:
89+
90+
import serializer.tosa_serializer as ts # type: ignore
91+
92+
if inputs[0].dtype != inputs[1].dtype:
93+
raise TypeError(
94+
"All inputs need to have the same data type for operator EQ but got "
95+
f"{inputs[0].dtype=}, {inputs[1].dtype=}"
96+
)
97+
98+
input_nodes = inputs
99+
# Handle quantization
100+
if inputs[0].dtype == ts.DType.INT8:
101+
# Rescale inputs to 32 bit
102+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
103+
tosa_graph, inputs, node, self.tosa_specs
104+
)
105+
106+
# Update IO
107+
input_nodes = rescaled_inputs
108+
109+
# Do the equal comparison
110+
tosa_graph.addOperator(
111+
ts.TosaOp.Op().EQUAL,
112+
[input_nodes[0].name, input_nodes[1].name],
113+
output.name,
114+
None,
115+
)

backends/arm/operators/op_ge.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,42 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
1615
)
1716
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
1818

1919
from torch.fx import Node
2020

2121

2222
@register_node_visitor
23-
class GreaterEqualVisitor(NodeVisitor):
23+
class GreaterEqualVisitor_0_80(NodeVisitor):
2424
target = "aten.ge.Tensor"
2525

26+
tosa_specs = [
27+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
28+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
29+
]
30+
2631
def __init__(self, *args):
2732
super().__init__(*args)
2833

2934
def define_node(
3035
self,
3136
node: Node,
32-
tosa_graph: ts.TosaSerializer,
37+
tosa_graph: Any,
3338
inputs: List[TosaArg],
3439
output: TosaArg,
3540
) -> None:
41+
42+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
43+
3644
if inputs[0].dtype != inputs[1].dtype:
3745
raise TypeError(
3846
"All inputs need to have the same data type for operator GE but got "
@@ -56,3 +64,50 @@ def define_node(
5664
[output.name],
5765
None,
5866
)
67+
68+
69+
@register_node_visitor
70+
class GreaterEqualVisitor(NodeVisitor):
71+
target = "aten.ge.Tensor"
72+
73+
tosa_specs = [
74+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
75+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
76+
]
77+
78+
def __init__(self, *args):
79+
super().__init__(*args)
80+
81+
def define_node(
82+
self,
83+
node: Node,
84+
tosa_graph: Any,
85+
inputs: List[TosaArg],
86+
output: TosaArg,
87+
) -> None:
88+
89+
import serializer.tosa_serializer as ts # type: ignore
90+
91+
if inputs[0].dtype != inputs[1].dtype:
92+
raise TypeError(
93+
"All inputs need to have the same data type for operator GE but got "
94+
f"{inputs[0].dtype=}, {inputs[1].dtype=}"
95+
)
96+
97+
input_nodes = inputs
98+
# Handle quantization
99+
if inputs[0].dtype == ts.DType.INT8:
100+
# Rescale inputs to 32 bit
101+
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
102+
tosa_graph, inputs, node, self.tosa_specs
103+
)
104+
105+
# Update IO
106+
input_nodes = rescaled_inputs
107+
108+
tosa_graph.addOperator(
109+
ts.TosaOp.Op().GREATER_EQUAL,
110+
[input_nodes[0].name, input_nodes[1].name],
111+
[output.name],
112+
None,
113+
)

0 commit comments

Comments
 (0)