Skip to content

Commit 13569b7

Browse files
authored
Arm backend: Move TOSA operators to dialect (#13408)
### Summary Move rescale, table and transpose TOSA operators to new implementation. ### Test plan Tested through existing CI unit tests. Signed-off-by: Per Åstrand <[email protected]>
1 parent 9c67384 commit 13569b7

16 files changed

+263
-131
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,12 @@
1414
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass, PassResult
17-
from torch.library import impl, Library
18-
19-
# Define lib with passthrough operators. The operators have no real meaning in edge IR
20-
# except for argument validaiton and a passthrough output. The operators will be used
21-
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
22-
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
23-
lib = Library("passthrough_to_tosa", "DEF")
24-
# For certain operators we need the data in a specific data format. Changing tosa_dim_order
25-
# is not sufficient as we also need transpose the data.
26-
# By utilizing an edge IR passthrough operator we can keep the edge program in
27-
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
28-
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
29-
30-
31-
@impl(lib, "_transpose")
32-
def _transpose_impl(*args, **kwargs):
33-
# Validate length of dim_order array
34-
dim = args[1]
35-
if len(dim) != 4 and len(dim) != 5:
36-
raise ValueError(
37-
f"Dim order length must be either 4 or 5, got {len(dim)}: {dim}"
38-
)
39-
# Pass-through in edge-IR
40-
return args[0]
4117

4218

4319
class AnnotateChannelsLastDimOrder(ExportPass):
4420
"""
4521
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
46-
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
22+
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
4723
when a transition between 3D and 4D/5D tensors happen.
4824
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
4925
"""
@@ -119,7 +95,7 @@ def insert_input_transpose(node, input_node, graph_module):
11995
with graph_module.graph.inserting_before(node):
12096
permute_node = create_node(
12197
graph_module.graph,
122-
torch.ops.passthrough_to_tosa._transpose.default,
98+
exir_ops.backend.tosa.TRANSPOSE.default,
12399
args=(
124100
input_node,
125101
list(
@@ -141,7 +117,7 @@ def insert_output_transpose(node, graph_module):
141117
with graph_module.graph.inserting_after(node):
142118
permute_node = create_node(
143119
graph_module.graph,
144-
torch.ops.passthrough_to_tosa._transpose.default,
120+
exir_ops.backend.tosa.TRANSPOSE.default,
145121
args=(
146122
node,
147123
list(

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def call(self, graph_module):
107107
for node in graph_module.graph.nodes:
108108
if node.op != "call_function":
109109
continue
110-
if node.target == torch.ops.tosa._table.default:
110+
if node.target == exir_ops.backend.tosa.TABLE.default:
111111
continue
112112

113113
input_nodes = node.all_input_nodes

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 5 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,70 +3,25 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import logging
76
from copy import copy
87
from typing import cast
98

10-
import torch
119
from executorch.backends.arm._passes.arm_pass_utils import create_node
1210
from executorch.backends.arm._passes.quant_args import QuantArgs
1311
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
12+
from executorch.exir.dialects._ops import ops as exir_ops
1413
from executorch.exir.pass_base import ExportPass, PassResult
15-
from torch import Tensor
1614
from torch.fx import GraphModule, Node
17-
from torch.library import custom_op, register_fake
18-
19-
logger = logging.getLogger(__name__)
20-
21-
22-
@custom_op("tosa::_rescale", mutates_args=()) # type: ignore[misc]
23-
def rescale(
24-
x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
25-
) -> Tensor:
26-
logger.warning(
27-
"Ran default implementation of tosa::_rescale."
28-
"This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
29-
)
30-
# Clone is needed to not return reference when rescaling to same dtype.
31-
# This is a neccessary requirement for non-mutating custom ops.
32-
return x.to(dtype=dtype).clone()
33-
34-
35-
@register_fake("tosa::_rescale") # type: ignore[misc]
36-
def rescale_fake(
37-
x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
38-
) -> Tensor:
39-
"""Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
40-
Additionally validates TOSA constraints of a RESCALE op.
41-
"""
42-
if dtype not in (torch.int32, torch.int8, torch.int16):
43-
raise NotImplementedError(
44-
f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}"
45-
)
46-
if dtype in (torch.int32, torch.int16) and out_zp != 0:
47-
raise ValueError(
48-
f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
49-
)
50-
if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
51-
raise ValueError(
52-
f"TOSA requires input_zp to be zero when the input dtype is {dtype}"
53-
)
54-
if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
55-
raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
56-
if dtype == torch.int8 and not -128 <= out_zp <= 127:
57-
raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")
58-
59-
return x.to(dtype=dtype).clone()
6015

6116

6217
class InsertRescalePass(ExportPass):
6318
"""Finds patterns of dq -> q, and replaces them
64-
with passthrough_to_tosa::rescales.
19+
with backend dialect tosa::RESCALE op.
6520
66-
Does not garantuee that the dtypes and zero points are valid
21+
Does not guarantee that the dtypes and zero points are valid
6722
in TOSA, that is the job of the quantization annotator that
6823
produced the dq and q nodes. The TOSA constraints are validated
69-
in the fake implementation of passthrough_to_tosa:rescale.
24+
in the fake implementation of.
7025
"""
7126

7227
def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
@@ -77,7 +32,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
7732
with graph_module.graph.inserting_before(node):
7833
rescale_node = create_node(
7934
graph_module.graph,
80-
torch.ops.tosa._rescale.default,
35+
exir_ops.backend.tosa.RESCALE.default,
8136
(
8237
node.all_input_nodes[0],
8338
q_args.dtype,

backends/arm/_passes/insert_table_ops.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,17 @@
1111
import torch
1212
from executorch.backends.arm._passes.arm_pass_utils import create_node
1313
from executorch.backends.arm._passes.quant_args import QuantArgs
14+
from executorch.backends.transforms.utils import create_constant_placeholder
15+
1416
from executorch.exir import ExportedProgram
1517

1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1820

1921
from executorch.exir.pass_base import ExportPass, PassResult
22+
from torch.export.graph_signature import InputKind
2023
from torch.fx import GraphModule
2124
from torch.fx.node import Node
22-
from torch.library import impl, Library
23-
24-
lib = Library("tosa", "DEF")
25-
lib.define("_table(Tensor self) -> Tensor")
26-
27-
28-
@impl(lib, "_table")
29-
def _table_impl(*args, **kwargs): # pyre-ignore
30-
in_dtype = args[0].dtype
31-
if in_dtype == torch.int8:
32-
return args[0]
33-
return args[0].to(dtype=torch.int32)
3425

3526

3627
class TableOps:
@@ -242,13 +233,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
242233
# We only want to replace the node if it's quantized
243234
continue
244235
# Create table node
245-
with graph_module.graph.inserting_before(node):
246-
table_node = create_node(
247-
graph=graph_module.graph,
248-
op_target=torch.ops.tosa._table.default,
249-
args=(node.args[0],),
250-
)
251-
output_node = table_node
236+
insert_pos = list(node.graph.nodes)[0]
237+
with graph_module.graph.inserting_before(insert_pos):
252238
# Expect exactly one quantization parameter for input and output
253239
if len(input_qparams) != 1:
254240
raise ValueError(
@@ -268,27 +254,37 @@ def call(self, graph_module: GraphModule) -> PassResult:
268254
out_quantargs=output_qparams[0],
269255
)
270256
# Register buffer in self.exported_program.state_dict
271-
# When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name
272-
# Remove it here to make it possible to find in the node_visitor
273-
self.register_buffer(
274-
buffer_name=table_node.name.replace("_default", ""), buffer=buffer
257+
const_table_node = create_constant_placeholder(
258+
exp_program=self.exported_program,
259+
graph=node.graph,
260+
kind=InputKind.BUFFER,
261+
name=node.name + "_table_constant",
262+
data=buffer,
263+
persistent_buffer=True,
275264
)
276265

266+
# Create table node
267+
with graph_module.graph.inserting_before(node):
268+
table_op_node = create_node(
269+
graph=graph_module.graph,
270+
op_target=exir_ops.backend.tosa.TABLE.default,
271+
args=(node.args[0], const_table_node),
272+
)
273+
output_node = table_op_node
274+
277275
if lshift != 0:
278276
scale = 2.0**lshift
279277
rescale_node = create_node(
280278
graph=graph_module.graph,
281-
op_target=torch.ops.tosa._rescale.default,
282-
args=(table_node, output_qparams[0].dtype, scale, 0, 0),
279+
op_target=exir_ops.backend.tosa.RESCALE.default,
280+
args=(table_op_node, output_qparams[0].dtype, scale, 0, 0),
283281
)
284282
output_node = rescale_node
285283

286284
node.replace_all_uses_with(output_node)
287-
288285
graph_module.graph.erase_node(node)
289-
290-
output_node.meta["input_qparams"] = input_qparams
291-
output_node.meta["output_qparams"] = output_qparams
286+
table_op_node.meta["input_qparams"] = input_qparams
287+
table_op_node.meta["output_qparams"] = output_qparams
292288
modified = True
293289

294290
if modified:

backends/arm/operators/op_rescale.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424

2525
@register_node_visitor
26-
class RescaleVisitor_INT(NodeVisitor):
27-
target = "_rescale.default"
26+
class RescaleVisitor(NodeVisitor):
27+
target = "tosa.RESCALE.default"
2828

2929
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")]
3030

backends/arm/operators/op_table.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
@register_node_visitor
2525
class TableVisitor(NodeVisitor):
26-
target = "_table.default"
26+
target = "tosa.TABLE.default"
2727

2828
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")]
2929

@@ -36,7 +36,7 @@ def define_node(
3636
) -> None:
3737
import serializer.tosa_serializer as ts # type: ignore
3838

39-
validate_num_inputs(self.target, inputs, 1)
39+
validate_num_inputs(self.target, inputs, 2)
4040
validate_valid_dtype(
4141
self.target, inputs, [ts.DType.INT8, ts.DType.INT16], output.tosa_spec
4242
)
@@ -45,12 +45,12 @@ def define_node(
4545
if inputs[0].dtype == ts.DType.INT16:
4646
validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec)
4747

48-
if node.name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr]
48+
if inputs[1].name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr]
4949
raise RuntimeError(
5050
f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
5151
)
5252

53-
table = self._exported_program.state_dict[node.name]
53+
table = self._exported_program.state_dict[inputs[1].name] # type: ignore[union-attr]
5454

5555
table_tensor_name = node.name + "_table"
5656
tosa_graph.addConst(

backends/arm/operators/op_transpose.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
@register_node_visitor
2525
class TransposeVisitor(NodeVisitor):
2626
"""
27-
This node visitor targets the _transpose op defined in the
28-
passthrough_to_tosa library. Used when switching between tosa_dim_orders.
27+
This node visitor targets the tosa::TRANSPOSE op defined in the
28+
TOSA backend dialect. Used when switching between tosa_dim_orders.
2929
Inserts a TOSA TRANSPOSE.
3030
"""
3131

32-
target = "_transpose.default"
32+
target = "tosa.TRANSPOSE.default"
3333

3434
tosa_specs = NodeVisitor.tosa_specs
3535

backends/arm/test/passes/test_insert_table_ops_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ def test_insert_table_tosa_INT(test_data: input_t):
3333
module,
3434
test_data,
3535
quantize=True,
36-
ops_before_pass={},
36+
ops_before_pass={"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
3737
ops_after_pass={
3838
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1,
3939
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
40-
"tosa._table": 1,
40+
"backend__ops_tosa_TABLE_default": 1,
4141
},
42-
ops_not_after_pass=["aten_sigmoid_default"],
42+
ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_sigmoid_default"],
4343
pass_list=[FoldAndAnnotateQParamsPass],
4444
passes_with_exported_program=[InsertTableOpsPass],
4545
)

0 commit comments

Comments
 (0)