Skip to content

Commit 3b81be1

Browse files
Increase q/dq folding coverage
Add support for q/dq folding of more operators such as hardtanh, maxpool2d, mul, relu, select, sub, to_copy. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ifdabda4c927dade41c000859054696844c546f7b
1 parent 1f04bad commit 3b81be1

File tree

11 files changed

+181
-161
lines changed

11 files changed

+181
-161
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,6 @@ def transform_to_backend_pipeline(
7979
self.add_pass(DecomposeVarPass())
8080
self.add_pass(ConvertMeanDimToAveragePool())
8181
self.add_pass(DecomposeMeanDimPass())
82-
self.add_pass(MatchArgRanksPass(exported_program))
83-
self.add_pass(DecomposeDivPass())
84-
self.add_pass(KeepDimsFalseToSqueezePass())
85-
self.add_pass(ConvertSplitToSlicePass())
86-
self.add_pass(Conv1dUnsqueezePass(exported_program))
87-
self.add_pass(DecomposeSoftmaxesPass())
8882
self.add_pass(DecomposeLinearPass())
8983
self.add_pass(QuantizeFullArgument())
9084
self.add_pass(
@@ -96,17 +90,29 @@ def transform_to_backend_pipeline(
9690
exir_ops.edge.aten.convolution.default,
9791
exir_ops.edge.aten.exp.default,
9892
exir_ops.edge.aten.full.default,
93+
exir_ops.edge.aten.hardtanh.default,
9994
exir_ops.edge.aten.log.default,
95+
exir_ops.edge.aten.max_pool2d.default,
10096
exir_ops.edge.aten.maximum.default,
10197
exir_ops.edge.aten.minimum.default,
98+
exir_ops.edge.aten.mul.Tensor,
10299
exir_ops.edge.aten.reciprocal.default,
100+
exir_ops.edge.aten.relu.default,
103101
exir_ops.edge.aten.rsqrt.default,
102+
exir_ops.edge.aten.select_copy.int,
104103
exir_ops.edge.aten.sigmoid.default,
104+
exir_ops.edge.aten.sub.Tensor,
105105
exir_ops.edge.aten.tanh.default,
106106
]
107107
)
108108
)
109109
self.add_pass(InsertTableOpsPass(exported_program))
110+
self.add_pass(MatchArgRanksPass(exported_program))
111+
self.add_pass(DecomposeDivPass())
112+
self.add_pass(KeepDimsFalseToSqueezePass())
113+
self.add_pass(ConvertSplitToSlicePass())
114+
self.add_pass(Conv1dUnsqueezePass(exported_program))
115+
self.add_pass(DecomposeSoftmaxesPass())
110116
for spec in compile_spec:
111117
if spec.key == "permute_memory_format":
112118
memory_format = spec.value.decode()

backends/arm/operators/op_div.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

backends/arm/operators/op_hardtanh.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
import serializer.tosa_serializer as ts
1010
import torch
11+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
12+
get_input_qparams,
13+
)
1114
from executorch.backends.arm.operators.node_visitor import (
1215
NodeVisitor,
1316
register_node_visitor,
1417
)
1518
from executorch.backends.arm.tosa_mapping import TosaArg
1619

17-
from executorch.backends.arm.tosa_quant_utils import (
18-
get_quant_arg_upstream,
19-
quantize_value,
20-
)
20+
from executorch.backends.arm.tosa_quant_utils import quantize_value
2121
from serializer.tosa_serializer import TosaOp
2222

2323

@@ -38,9 +38,10 @@ def define_node(
3838
) -> None:
3939
attr = ts.TosaSerializerAttribute()
4040

41-
if is_quant_node:
41+
if inputs[0].dtype == ts.DType.INT8:
4242
# Get quant parameters
43-
qargs = get_quant_arg_upstream(node.all_input_nodes[0])
43+
input_qparams = get_input_qparams(node)
44+
qargs = input_qparams[0]
4445
# Convert to quantized representation
4546
clamp_min_qs = quantize_value(inputs[1].number, qargs)
4647
clamp_max_qs = quantize_value(inputs[2].number, qargs)

backends/arm/operators/op_max_pool2d.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@
88

99
import serializer.tosa_serializer as ts
1010
import torch
11+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
12+
get_input_qparams,
13+
get_output_qparams,
14+
)
1115
from executorch.backends.arm.operators.node_visitor import (
1216
NodeVisitor,
1317
register_node_visitor,
1418
)
1519
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from executorch.backends.arm.tosa_quant_utils import (
17-
get_quant_arg_downstream,
18-
get_quant_arg_upstream,
19-
)
20-
2120
from serializer.tosa_serializer import TosaOp
2221

2322

@@ -46,19 +45,18 @@ def define_node(
4645
except IndexError:
4746
padding = [0, 0, 0, 0]
4847

49-
accumulator_type = input_tensor.dtype
50-
51-
if is_quant_node:
52-
# Accumulator type always is int8 when input tensor is an integer type.
53-
accumulator_type = ts.DType.INT8
48+
accumulator_type = output.dtype
5449

5550
# Initialize zero point to zero.
5651
input_zp = 0
57-
output_zp = 0
52+
if inputs[0].dtype == ts.DType.INT8:
53+
input_qparams = get_input_qparams(node)
54+
input_zp = input_qparams[0].zp
5855

59-
if is_quant_node:
60-
input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp
61-
output_zp = get_quant_arg_downstream(list(node.users)[0]).zp
56+
output_zp = 0
57+
if output.dtype == ts.DType.INT8:
58+
output_qparams = get_output_qparams(node)
59+
output_zp = output_qparams[0].zp
6260

6361
attr = ts.TosaSerializerAttribute()
6462
attr.PoolAttribute(

backends/arm/operators/op_mul.py

Lines changed: 74 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,34 @@
55

66
# pyre-unsafe
77

8-
from typing import cast, List
8+
from typing import List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212

1313
import serializer.tosa_serializer as ts
1414
import torch
15+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
16+
get_input_qparams,
17+
)
1518

1619
from executorch.backends.arm.operators.node_visitor import (
1720
NodeVisitor,
1821
register_node_visitor,
1922
)
2023
from executorch.backends.arm.tosa_mapping import TosaArg
24+
from executorch.backends.arm.tosa_specification import TosaSpecification
2125
from serializer.tosa_serializer import TosaOp
2226

2327

2428
@register_node_visitor
25-
class MulVisitor(NodeVisitor):
29+
class MulVisitor_080_BI(NodeVisitor):
2630
target = "aten.mul.Tensor"
2731

32+
tosa_specs = [
33+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
34+
]
35+
2836
def define_node(
2937
self,
3038
node: torch.fx.Node,
@@ -33,57 +41,68 @@ def define_node(
3341
output: TosaArg,
3442
is_quant_node: bool,
3543
) -> None:
44+
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
45+
input_A = inputs[0]
46+
input_B = inputs[1]
47+
input_qparams = get_input_qparams(node)
48+
input_A_qargs = input_qparams[0]
49+
input_B_qargs = input_qparams[1]
50+
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
51+
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
52+
53+
# Rescale inputs to INT32 with zp=0
54+
input_A_rescaled = tqutils.build_rescale_to_int32(
55+
tosa_graph,
56+
input_A,
57+
input_A_qargs.zp,
58+
rescale_scale=1.0,
59+
)
60+
input_B_rescaled = tqutils.build_rescale_to_int32(
61+
tosa_graph,
62+
input_B,
63+
input_B_qargs.zp,
64+
rescale_scale=1.0,
65+
)
66+
67+
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
68+
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
69+
70+
# Do the INT32 Mul
71+
attr = ts.TosaSerializerAttribute()
72+
attr.MulAttribute(shift=0)
73+
tosa_graph.addOperator(
74+
TosaOp.Op().MUL,
75+
[
76+
input_A_rescaled.name,
77+
input_B_rescaled.name,
78+
],
79+
[mul_output.name],
80+
attr,
81+
)
82+
output_scale = input_A_qargs.scale * input_B_qargs.scale
83+
tqutils.insert_rescale_op_to_int8(tosa_graph, mul_output, output_scale, node)
84+
85+
86+
@register_node_visitor
87+
class MulVisitor_080_MI(MulVisitor_080_BI):
88+
# inheriting 'target' from BI class
89+
90+
tosa_specs = [
91+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
92+
]
3693

37-
if is_quant_node:
38-
input_A = inputs[0]
39-
input_B = inputs[1]
40-
input_A_qargs = tqutils.get_quant_arg_upstream(
41-
cast(torch.fx.Node, node.args[0])
42-
)
43-
input_B_qargs = tqutils.get_quant_arg_upstream(
44-
cast(torch.fx.Node, node.args[1])
45-
)
46-
47-
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
48-
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
49-
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
50-
51-
# Rescale inputs to INT32 with zp=0
52-
input_A_rescaled = tqutils.build_rescale_to_int32(
53-
tosa_graph,
54-
input_A,
55-
input_A_qargs.zp,
56-
rescale_scale=1.0,
57-
)
58-
input_B_rescaled = tqutils.build_rescale_to_int32(
59-
tosa_graph,
60-
input_B,
61-
input_B_qargs.zp,
62-
rescale_scale=1.0,
63-
)
64-
65-
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
66-
67-
# Do the INT32 Mul
68-
attr = ts.TosaSerializerAttribute()
69-
attr.MulAttribute(shift=0)
70-
tosa_graph.addOperator(
71-
TosaOp.Op().MUL,
72-
[
73-
input_A_rescaled.name,
74-
input_B_rescaled.name,
75-
],
76-
[mul_output.name],
77-
attr,
78-
)
79-
80-
tqutils.rescale_node_back_to_int8(
81-
node, mul_output, input_A_qargs.scale * input_B_qargs.scale, tosa_graph
82-
)
83-
84-
else:
85-
attr = ts.TosaSerializerAttribute()
86-
attr.MulAttribute(shift=0)
87-
tosa_graph.addOperator(
88-
TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
89-
)
94+
def define_node(
95+
self,
96+
node: torch.fx.Node,
97+
tosa_graph: ts.TosaSerializer,
98+
inputs: List[TosaArg],
99+
output: TosaArg,
100+
is_quant_node: bool,
101+
) -> None:
102+
if inputs[0].dtype == ts.DType.INT8:
103+
return super().define_node(node, tosa_graph, inputs, output, is_quant_node)
104+
attr = ts.TosaSerializerAttribute()
105+
attr.MulAttribute(shift=0)
106+
tosa_graph.addOperator(
107+
TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
108+
)

backends/arm/operators/op_relu.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import executorch.backends.arm.tosa_quant_utils as tqutils
99
import serializer.tosa_serializer as ts
1010
import torch.fx
11+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
12+
get_output_qparams,
13+
)
1114
from executorch.backends.arm.operators.node_visitor import (
1215
NodeVisitor,
1316
register_node_visitor,
@@ -37,10 +40,10 @@ def define_node(
3740
clamp_max_fp = 0.0
3841
clamp_min_qs = 0
3942
clamp_max_qs = 0
40-
if is_quant_node:
41-
out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0])
42-
clamp_min_qs = tqutils.quantize_value(0, out_qargs)
43-
clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs)
43+
if inputs[0].dtype == ts.DType.INT8:
44+
out_qargs = get_output_qparams(node)
45+
clamp_min_qs = tqutils.quantize_value(0, out_qargs[0])
46+
clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0])
4447

4548
else:
4649
clamp_min_fp = 0

backends/arm/operators/op_select.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ def define_node(
5050
expanded_shape = tuple(1 if i == dim else shape[i] for i in range(rank))
5151
expanded_shape = tosa_shape(expanded_shape, input_node.dim_order)
5252

53-
output_reshaped = tosa_graph.addIntermediate(
54-
expanded_shape, ts.DType.INT8 if is_quant_node else output.dtype
55-
)
53+
output_reshaped = tosa_graph.addIntermediate(expanded_shape, output.dtype)
5654

5755
attr_slice = ts.TosaSerializerAttribute()
5856

0 commit comments

Comments
 (0)