Skip to content

Commit 444c0aa

Browse files
committed
Arm Backend: Update element-wise nodevisitors affected by RESCALE updates
Requires: view node visitor, rescale node visitor Change-Id: Ie04f6dc20be22a8fa67d6527eb86bf82986236d9
1 parent a38e81f commit 444c0aa

File tree

4 files changed

+476
-21
lines changed

4 files changed

+476
-21
lines changed

backends/arm/operators/op_add.py

Lines changed: 133 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212

13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1413
from executorch.backends.arm.operators.node_visitor import (
1514
NodeVisitor,
1615
register_node_visitor,
@@ -34,10 +33,13 @@ def __init__(self, *args):
3433
def define_node(
3534
self,
3635
node: Node,
37-
tosa_graph: ts.TosaSerializer,
36+
tosa_graph: Any,
3837
inputs: List[TosaArg],
3938
output: TosaArg,
4039
) -> None:
40+
41+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
42+
4143
# Specification (0.80) states that input and output types
4244
# should all be the same
4345
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -58,7 +60,7 @@ def define_node(
5860
if len(inputs[0].shape) > len(inputs[1].shape)
5961
else inputs[1].dim_order
6062
)
61-
63+
scale_back = 1.0
6264
if inputs[0].dtype == ts.DType.INT8:
6365
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
6466
tosa_graph, inputs, node
@@ -90,7 +92,9 @@ def define_node(
9092
if output.dtype == ts.DType.INT8:
9193
# Scale output back to 8 bit
9294
# pyre-ignore
93-
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
95+
tqutils.insert_rescale_op_to_int8(
96+
tosa_graph, add_output, scale_back, node
97+
) # type: ignore[possibly-undefined]
9498

9599

96100
@register_node_visitor
@@ -107,10 +111,13 @@ def __init__(self, *args):
107111
def define_node(
108112
self,
109113
node: Node,
110-
tosa_graph: ts.TosaSerializer,
114+
tosa_graph: Any,
111115
inputs: List[TosaArg],
112116
output: TosaArg,
113117
) -> None:
118+
119+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120+
114121
# Specification (0.80) states that input and output types
115122
# should all be the same
116123
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -130,7 +137,7 @@ def define_node(
130137
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
131138
)
132139

133-
input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
140+
input1, input2 = inputs
134141

135142
# MI lowering
136143
tosa_graph.addOperator(
@@ -139,3 +146,122 @@ def define_node(
139146
[output.name],
140147
None,
141148
)
149+
150+
151+
@register_node_visitor
152+
class AddVisitor_INT(NodeVisitor):
153+
target = "aten.add.Tensor"
154+
155+
tosa_specs = [
156+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
157+
]
158+
159+
def __init__(self, *args):
160+
super().__init__(*args)
161+
162+
def define_node(
163+
self,
164+
node: Node,
165+
tosa_graph: Any,
166+
inputs: List[TosaArg],
167+
output: TosaArg,
168+
) -> None:
169+
170+
import serializer.tosa_serializer as ts # type: ignore
171+
172+
# Specification (1.0) states that input and output types
173+
# should all be the same
174+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
175+
raise TypeError(
176+
f"All IO needs to have the same data type, got input 1: "
177+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
178+
f"{output.dtype}"
179+
)
180+
# Handle int8 (quantized) and int32
181+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
182+
if inputs[0].dtype not in supported_dtypes:
183+
raise TypeError(
184+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
185+
)
186+
scale_back = 1.0
187+
if inputs[0].dtype == ts.DType.INT8:
188+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
189+
tosa_graph, inputs, node, self.tosa_specs
190+
)
191+
else:
192+
# input[0].dtype == ts.DType.INT32
193+
# Non quantized input, natively support by TOSA.ADD
194+
rescaled_inputs = inputs
195+
196+
if output.dtype == ts.DType.INT8:
197+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
198+
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
199+
else:
200+
# output.dtype == ts.DType.INT32
201+
add_output = output
202+
203+
input1, input2 = rescaled_inputs
204+
205+
# Do the INT32 Add
206+
tosa_graph.addOperator(
207+
ts.TosaOp.Op().ADD,
208+
[input1.name, input2.name],
209+
[add_output.name],
210+
None,
211+
)
212+
213+
if output.dtype == ts.DType.INT8:
214+
# Scale output back to 8 bit
215+
# pyre-ignore
216+
tqutils.insert_rescale_op_to_int8(
217+
tosa_graph, add_output, scale_back, node, self.tosa_specs
218+
) # type: ignore[possibly-undefined]
219+
220+
221+
@register_node_visitor
222+
class AddVisitor_FP(AddVisitor_INT):
223+
# inheriting 'target' from INT class
224+
225+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
226+
227+
def __init__(self, *args):
228+
super().__init__(*args)
229+
230+
def define_node(
231+
self,
232+
node: Node,
233+
tosa_graph: Any,
234+
inputs: List[TosaArg],
235+
output: TosaArg,
236+
) -> None:
237+
238+
import serializer.tosa_serializer as ts # type: ignore
239+
240+
# Specification (1.0) states that input and output types
241+
# should all be the same
242+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
243+
raise TypeError(
244+
f"All IO needs to have the same data type, got input 1: "
245+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
246+
f"{output.dtype}"
247+
)
248+
249+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
250+
# Call the inherited define_node for handling integers
251+
super().define_node(node, tosa_graph, inputs, output)
252+
else:
253+
# FP32 Add lowering
254+
if inputs[0].dtype != ts.DType.FP32:
255+
raise TypeError(
256+
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
257+
)
258+
259+
input1, input2 = inputs
260+
261+
# FP lowering
262+
tosa_graph.addOperator(
263+
ts.TosaOp.Op().ADD,
264+
[input1.name, input2.name],
265+
[output.name],
266+
None,
267+
)

backends/arm/operators/op_mul.py

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212
import torch
1313

14-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
15-
1614
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1715
get_input_qparams,
1816
)
@@ -37,10 +35,13 @@ class MulVisitor_080_BI(NodeVisitor):
3735
def define_node(
3836
self,
3937
node: torch.fx.Node,
40-
tosa_graph: ts.TosaSerializer,
38+
tosa_graph: Any,
4139
inputs: List[TosaArg],
4240
output: TosaArg,
4341
) -> None:
42+
43+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
44+
4445
if (
4546
inputs[0].dtype != ts.DType.INT8
4647
or inputs[1].dtype != ts.DType.INT8
@@ -114,10 +115,13 @@ class MulVisitor_080_MI(MulVisitor_080_BI):
114115
def define_node(
115116
self,
116117
node: torch.fx.Node,
117-
tosa_graph: ts.TosaSerializer,
118+
tosa_graph: Any,
118119
inputs: List[TosaArg],
119120
output: TosaArg,
120121
) -> None:
122+
123+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
124+
121125
if inputs[0].dtype == ts.DType.INT8:
122126
return super().define_node(node, tosa_graph, inputs, output)
123127

@@ -128,3 +132,100 @@ def define_node(
128132
tosa_graph.addOperator(
129133
ts.TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr
130134
)
135+
136+
137+
@register_node_visitor
138+
class MulVisitor_INT(NodeVisitor):
139+
target = "aten.mul.Tensor"
140+
141+
tosa_specs = [
142+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
143+
]
144+
145+
def define_node(
146+
self,
147+
node: torch.fx.Node,
148+
tosa_graph: Any,
149+
inputs: List[TosaArg],
150+
output: TosaArg,
151+
) -> None:
152+
153+
import serializer.tosa_serializer as ts # type: ignore
154+
155+
if (
156+
inputs[0].dtype != ts.DType.INT8
157+
or inputs[1].dtype != ts.DType.INT8
158+
or output.dtype != ts.DType.INT8
159+
):
160+
raise ValueError(
161+
f"Inputs and output for {self.target} need to be INT8, got "
162+
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
163+
)
164+
165+
input_A = inputs[0]
166+
input_B = inputs[1]
167+
input_qparams = get_input_qparams(node)
168+
input_A_qargs = input_qparams[0]
169+
input_B_qargs = input_qparams[1]
170+
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
171+
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
172+
173+
# Rescale inputs to INT32 with zp=0
174+
input_A_rescaled = tqutils.build_rescale_to_int32(
175+
tosa_graph,
176+
input_A,
177+
input_A_qargs.zp,
178+
[1.0],
179+
tosa_spec=self.tosa_specs,
180+
)
181+
input_B_rescaled = tqutils.build_rescale_to_int32(
182+
tosa_graph,
183+
input_B,
184+
input_B_qargs.zp,
185+
[1.0],
186+
tosa_spec=self.tosa_specs,
187+
)
188+
189+
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
190+
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
191+
192+
# Do the INT32 Mul
193+
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
194+
tosa_graph.addOperator(
195+
ts.TosaOp.Op().MUL,
196+
[input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"],
197+
[mul_output.name],
198+
)
199+
output_scale = input_A_qargs.scale * input_B_qargs.scale
200+
tqutils.insert_rescale_op_to_int8(
201+
tosa_graph, mul_output, output_scale, node, self.tosa_specs
202+
)
203+
204+
205+
@register_node_visitor
206+
class MulVisitor_FP(MulVisitor_INT):
207+
# inheriting 'target' from INT class
208+
209+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
210+
211+
def define_node(
212+
self,
213+
node: torch.fx.Node,
214+
tosa_graph: Any,
215+
inputs: List[TosaArg],
216+
output: TosaArg,
217+
) -> None:
218+
219+
import serializer.tosa_serializer as ts # type: ignore
220+
221+
if inputs[0].dtype == ts.DType.INT8:
222+
return super().define_node(node, tosa_graph, inputs, output)
223+
224+
input1, input2 = inputs
225+
226+
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
227+
tosa_graph.addOperator(
228+
ts.TosaOp.Op().MUL,
229+
[input1.name, input2.name, f"{node.name}_shift"],
230+
[output.name],
231+
)

0 commit comments

Comments
 (0)