Skip to content

Commit 45c1976

Browse files
Add TOSA table as custom edge op
Edge operators that are lowered to TOSA TABLEs are converted to a custom edge IR table-op. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I147008c30b9b46c7b8ae1a1c15bc540fea614a69
1 parent 5190106 commit 45c1976

File tree

10 files changed

+212
-291
lines changed

10 files changed

+212
-291
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
FoldAndAnnotateQParamsPass,
3434
QuantizeFullArgument,
3535
)
36+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
3637
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3738
KeepDimsFalseToSqueezePass,
3839
)
@@ -94,10 +95,17 @@ def transform_to_backend_pipeline(
9495
exir_ops.edge.aten.add.Tensor,
9596
exir_ops.edge.aten.avg_pool2d.default,
9697
exir_ops.edge.aten.convolution.default,
98+
exir_ops.edge.aten.exp.default,
9799
exir_ops.edge.aten.full.default,
100+
exir_ops.edge.aten.log.default,
101+
exir_ops.edge.aten.reciprocal.default,
102+
exir_ops.edge.aten.rsqrt.default,
103+
exir_ops.edge.aten.sigmoid.default,
104+
exir_ops.edge.aten.tanh.default,
98105
]
99106
)
100107
)
108+
self.add_pass(InsertTableOpsPass(exported_program))
101109
for spec in compile_spec:
102110
if spec.key == "permute_memory_format":
103111
memory_format = spec.value.decode()
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Callable
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
12+
from executorch.exir import ExportedProgram
13+
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
16+
from executorch.exir.pass_base import ExportPass, PassResult
17+
from torch.fx import GraphModule
18+
from torch.library import impl, Library
19+
20+
# Open the "tosa" operator namespace for definitions and declare the schema of
# the custom table op: one tensor in, one tensor out.
lib = Library("tosa", "DEF")
lib.define("_table(Tensor self) -> Tensor")


@impl(lib, "_table")
def _table_impl(*args, **kwargs):
    """Return the first positional argument unchanged.

    Any further positional arguments and all keyword arguments are ignored.
    """
    passthrough = args[0]
    return passthrough
27+
28+
29+
class InsertTableOpsPass(ExportPass):
    """
    Replace quantized element-wise edge ops with the custom ``tosa._table`` op.

    Every ``call_function`` node whose target is a key of ``table_ops`` and that
    carries both input and output quantization parameters in its ``meta`` is
    swapped for a ``tosa._table(Tensor self) -> Tensor`` node. The table values
    are computed in this pass by evaluating the matching torch operator
    (the value in ``table_ops``) over every quantized input value, and are stored
    in ``exported_program.state_dict`` under the name of the new table node.
    Nodes without quantization parameters are left untouched.

    NOTE(review): the stored buffer is presumably read back when the _table node
    is lowered in operators/op_table.py -- confirm against that file.
    """

    # Maps each supported edge op to the torch function used to compute its table.
    table_ops = {
        exir_ops.edge.aten.exp.default: torch.exp,
        exir_ops.edge.aten.log.default: torch.log,
        exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
        exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
        exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
        exir_ops.edge.aten.tanh.default: torch.tanh,
    }

    def __init__(self, exported_program: ExportedProgram):
        """Keep a reference to the program whose state_dict receives the tables."""
        super().__init__()
        self.exported_program = exported_program

    def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
        """
        Add buffer to self.exported_program.state_dict under buffer_name.

        An existing entry with the same name is overwritten.
        """
        self.exported_program.state_dict[buffer_name] = buffer

    def generate_table_values(
        self,
        torch_op: Callable[[torch.Tensor], torch.Tensor],
        in_quantargs: QuantArgs,
        out_quantargs: QuantArgs,
    ) -> torch.Tensor:
        """
        Compute the quantized result of torch_op for every quantized input value.

        Args:
            torch_op: Element-wise function applied to the dequantized inputs.
            in_quantargs: Quantization parameters of the input; its qmin/qmax
                bound the table and its dtype is the dtype of the result.
            out_quantargs: Quantization parameters used to requantize the result.

        Returns:
            A 1-D tensor with ``qmax - qmin + 1`` entries, ordered from the entry
            for ``qmin`` to the entry for ``qmax``.
        """
        # Enumerate every quantized input value exactly once. torch.arange is
        # exact for integers, and int64 avoids overflow when dequantizing
        # (subtracting zp), e.g.
        # torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
        quantized_inputs = torch.arange(
            in_quantargs.qmin, in_quantargs.qmax + 1, dtype=torch.int64
        )
        real_inputs = in_quantargs.dequantize_value(quantized_inputs)
        real_outputs = torch_op(real_inputs)
        table = out_quantargs.quantize_value(real_outputs)
        # NOTE(review): the table is cast to the *input* dtype; this assumes input
        # and output share a dtype -- confirm before using mixed-width quantization.
        return table.to(dtype=in_quantargs.dtype)

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Swap every quantized table op in graph_module for a tosa._table node.

        Returns:
            PassResult holding the (possibly retraced) graph module and a flag
            telling whether any node was replaced.

        Raises:
            ValueError: If a quantized table op does not have exactly one input
                and one output quantization parameter.
        """
        modified = False
        # Iterate over a snapshot: nodes are inserted and erased inside the loop.
        for node in list(graph_module.graph.nodes):
            if node.op != "call_function" or node.target not in self.table_ops:
                continue
            # Nodes that were never annotated have no qparams entries at all;
            # treat them like nodes with empty qparams instead of raising KeyError.
            input_qparams = node.meta.get("input_qparams", {})
            output_qparams = node.meta.get("output_qparams", {})
            if not input_qparams or not output_qparams:
                # We only want to replace the node if it's quantized
                continue
            # Validate before touching the graph so that a failure does not
            # leave a half-inserted table node behind.
            if len(input_qparams) != 1 or len(output_qparams) != 1:
                raise ValueError(
                    f"{node.name}: expected exactly one input and one output "
                    f"qparam, got {len(input_qparams)} and {len(output_qparams)}"
                )

            # Create table node
            with graph_module.graph.inserting_before(node):
                table_node = create_node(
                    graph=graph_module.graph,
                    op_target=torch.ops.tosa._table,
                    args=(node.args[0],),
                )
            # Generate table buffer
            buffer = self.generate_table_values(
                torch_op=self.table_ops[node.target],
                in_quantargs=input_qparams[0],
                out_quantargs=output_qparams[0],
            )
            # Register buffer in self.exported_program.state_dict
            self.register_buffer(buffer_name=table_node.name, buffer=buffer)
            node.replace_all_uses_with(table_node)
            graph_module.graph.erase_node(node)
            # Carry the quantization parameters over to the replacement node.
            table_node.meta["input_qparams"] = input_qparams
            table_node.meta["output_qparams"] = output_qparams
            modified = True

        if modified:
            # retrace the graph to update the fake tensor types
            graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, modified)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
op_squeeze,
3838
op_sub,
3939
op_sum,
40+
op_table,
4041
op_tanh,
4142
op_to_copy,
4243
op_transpose,

backends/arm/operators/op_exp.py

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,25 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import numpy as np
10-
119
import serializer.tosa_serializer as ts
1210
from executorch.backends.arm.operators.node_visitor import (
1311
NodeVisitor,
1412
register_node_visitor,
1513
)
1614
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
1716

18-
from executorch.backends.arm.tosa_quant_utils import (
19-
dequantize_value,
20-
get_quant_arg_downstream,
21-
get_quant_arg_upstream,
22-
QuantArgs,
23-
quantize_value,
24-
)
2517
from serializer.tosa_serializer import TosaOp
2618
from torch.fx import Node
2719

2820

2921
@register_node_visitor
30-
class ExpVisitor(NodeVisitor):
22+
class ExpVisitor_0_80_MI(NodeVisitor):
3123
target = "aten.exp.default"
3224

25+
# BI case should be handled by op_table
26+
tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]
27+
3328
def __init__(self, *args):
3429
super().__init__(*args)
3530

@@ -43,41 +38,6 @@ def define_node(
4338
) -> None:
4439

4540
assert len(node.all_input_nodes) == 1
41+
assert inputs[0].dtype == output.dtype == ts.DType.FP32
4642

47-
if is_quant_node:
48-
# Assume quantized input is 8 bit.
49-
50-
# Create attribute for 8 bit table lookup.
51-
input_node = node.all_input_nodes[0]
52-
in_quantargs = get_quant_arg_upstream(input_node)
53-
output_node = list(node.users)[0]
54-
out_quantargs = get_quant_arg_downstream(output_node)
55-
56-
table = exp_table_8bit(in_quantargs, out_quantargs)
57-
table_attr = ts.TosaSerializerAttribute()
58-
table_attr.TableAttribute(table)
59-
60-
tosa_graph.addOperator(
61-
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
62-
)
63-
else:
64-
tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])
65-
66-
67-
def exp_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
68-
"""
69-
Returns a table mapping 256 entries to exp([qmin,qmax])
70-
"""
71-
72-
def exp(x):
73-
# Convert quantized input to floating point exp input space.
74-
v = dequantize_value(x, in_quantargs)
75-
# Compute exp.
76-
v = np.exp(v)
77-
# Convert exp output back to quantized space.
78-
return quantize_value(v, out_quantargs)
79-
80-
return [
81-
exp(x)
82-
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
83-
]
43+
tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])

backends/arm/operators/op_log.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,14 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import numpy as np
10-
119
import serializer.tosa_serializer as ts
1210
from executorch.backends.arm.operators.node_visitor import (
1311
NodeVisitor,
1412
register_node_visitor,
1513
)
1614
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
1716

18-
from executorch.backends.arm.tosa_quant_utils import (
19-
dequantize_value,
20-
get_quant_arg_downstream,
21-
get_quant_arg_upstream,
22-
QuantArgs,
23-
quantize_value,
24-
)
2517
from serializer.tosa_serializer import TosaOp
2618
from torch.fx import Node
2719

@@ -30,6 +22,9 @@
3022
class LogVisitor(NodeVisitor):
3123
target = "aten.log.default"
3224

25+
# BI case should be handled by op_table
26+
tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]
27+
3328
def __init__(self, *args):
3429
super().__init__(*args)
3530

@@ -41,44 +36,8 @@ def define_node(
4136
output: TosaArg,
4237
is_quant_node: bool,
4338
) -> None:
44-
4539
assert len(node.all_input_nodes) == 1
4640
assert len(node.users) == 1
41+
assert inputs[0].dtype == output.dtype == ts.DType.FP32
4742

48-
if is_quant_node:
49-
# Assume quantized input is 8 bit.
50-
51-
# Create attribute for 8 bit table lookup.
52-
input_node = node.all_input_nodes[0]
53-
in_quantargs = get_quant_arg_upstream(input_node)
54-
output_node = list(node.users)[0]
55-
out_quantargs = get_quant_arg_downstream(output_node)
56-
57-
table = log_table_8bit(in_quantargs, out_quantargs)
58-
table_attr = ts.TosaSerializerAttribute()
59-
table_attr.TableAttribute(table)
60-
61-
tosa_graph.addOperator(
62-
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
63-
)
64-
else:
65-
tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name])
66-
67-
68-
def log_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
69-
"""
70-
Returns a table mapping 256 entries to log([qmin,qmax])
71-
"""
72-
73-
def log(x):
74-
# Convert quantized input to floating point log input space.
75-
v = dequantize_value(x, in_quantargs)
76-
# Compute log.
77-
v = np.log(v)
78-
# Convert log output back to quantized space.
79-
return quantize_value(v, out_quantargs)
80-
81-
return [
82-
log(x)
83-
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
84-
]
43+
tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name])

backends/arm/operators/op_reciprocal.py

Lines changed: 7 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,24 @@
66
# pyre-unsafe
77
from typing import List
88

9-
import numpy as np
10-
119
import serializer.tosa_serializer as ts
1210
import torch
1311
from executorch.backends.arm.operators.node_visitor import (
1412
NodeVisitor,
1513
register_node_visitor,
1614
)
1715
from executorch.backends.arm.tosa_mapping import TosaArg
18-
from executorch.backends.arm.tosa_quant_utils import (
19-
dequantize_value,
20-
get_quant_arg_downstream,
21-
get_quant_arg_upstream,
22-
QuantArgs,
23-
quantize_value,
24-
)
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
2517
from serializer.tosa_serializer import TosaOp
2618

2719

2820
@register_node_visitor
29-
class DivVisitor(NodeVisitor):
21+
class ReciprocalVisitor_080_MI(NodeVisitor):
3022
target = "aten.reciprocal.default"
3123

24+
# BI case should be handled by op_table
25+
tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]
26+
3227
def __init__(self, *args):
3328
super().__init__(*args)
3429

@@ -40,43 +35,5 @@ def define_node(
4035
output: TosaArg,
4136
is_quant_node: bool,
4237
) -> None:
43-
# 1/X
44-
45-
if is_quant_node:
46-
input = inputs[0]
47-
input_qargs = get_quant_arg_upstream(node.all_input_nodes[0])
48-
output_qargs = get_quant_arg_downstream(list(node.users)[0])
49-
50-
div_table = div_table_8bit(input_qargs, output_qargs)
51-
52-
table_attr = ts.TosaSerializerAttribute()
53-
table_attr.TableAttribute(div_table)
54-
tosa_graph.addOperator(
55-
TosaOp.Op().TABLE, [input.name], [output.name], table_attr
56-
)
57-
58-
else:
59-
tosa_graph.addOperator(
60-
TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
61-
)
62-
63-
64-
def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
65-
"""
66-
Returns a table mapping 256 entries to div([qmin,qmax])
67-
"""
68-
69-
def div(x):
70-
# Convert quantized input to floating point div input space.
71-
v1 = dequantize_value(x, in_quantargs)
72-
# Compute div.
73-
v2 = 1.0 / v1
74-
# Convert div output back to quantized space.
75-
v3 = quantize_value(v2, out_quantargs)
76-
77-
return v3
78-
79-
return [
80-
div(x)
81-
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
82-
]
38+
assert inputs[0].dtype == output.dtype == ts.DType.FP32
39+
tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name])

0 commit comments

Comments
 (0)