Skip to content

Commit f9a5d46

Browse files
Add support for sum q/dq folding
sum is retraced to an operator with int64 dtype after q/dq folding. This patch adds a pass to manually force the dtype to be int8. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ifa737a398c5a878d52cd76a2392499905da085ce
1 parent 3b81be1 commit f9a5d46

File tree

3 files changed

+119
-56
lines changed

3 files changed

+119
-56
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
3333
FoldAndAnnotateQParamsPass,
3434
QuantizeFullArgument,
35+
RetraceFoldedDtypesPass,
3536
)
3637
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
3738
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
@@ -102,14 +103,16 @@ def transform_to_backend_pipeline(
102103
exir_ops.edge.aten.select_copy.int,
103104
exir_ops.edge.aten.sigmoid.default,
104105
exir_ops.edge.aten.sub.Tensor,
106+
exir_ops.edge.aten.sum.dim_IntList,
105107
exir_ops.edge.aten.tanh.default,
106108
]
107109
)
108110
)
111+
self.add_pass(RetraceFoldedDtypesPass())
109112
self.add_pass(InsertTableOpsPass(exported_program))
113+
self.add_pass(KeepDimsFalseToSqueezePass())
110114
self.add_pass(MatchArgRanksPass(exported_program))
111115
self.add_pass(DecomposeDivPass())
112-
self.add_pass(KeepDimsFalseToSqueezePass())
113116
self.add_pass(ConvertSplitToSlicePass())
114117
self.add_pass(Conv1dUnsqueezePass(exported_program))
115118
self.add_pass(DecomposeSoftmaxesPass())

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,33 @@ def call(self, graph_module: GraphModule) -> PassResult:
195195
modified = True
196196

197197
return PassResult(graph_module, modified)
198+
199+
200+
class RetraceFoldedDtypesPass(ExportPass):
201+
"""
202+
FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
203+
some operators are retraced to types that cannot be handled by TOSA. One
204+
such example is sum.dim_IntList:
205+
q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
206+
After folding it becomes:
207+
q (int8) -> sum (int64) -> ...
208+
This pass changes types of ops in self.targeted_ops, such as sum, so that
209+
the output type of the op matches the type of the output_qparams.
210+
"""
211+
212+
targeted_ops = {
213+
exir_ops.edge.aten.sum.dim_IntList,
214+
}
215+
216+
def call_operator(self, op, args, kwargs, meta):
217+
if op not in self.targeted_ops:
218+
return super().call_operator(op, args, kwargs, meta)
219+
220+
node_kwargs = kwargs.copy()
221+
output_qparams = meta["output_qparams"]
222+
if len(output_qparams) == 0:
223+
return super().call_operator(op, args, kwargs, meta)
224+
225+
output_dtype = output_qparams[0].dtype
226+
node_kwargs["dtype"] = output_dtype
227+
return super().call_operator(op, args, node_kwargs, meta)

backends/arm/operators/op_sum.py

Lines changed: 85 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616
register_node_visitor,
1717
)
1818
from executorch.backends.arm.tosa_mapping import TosaArg
19+
from executorch.backends.arm.tosa_specification import TosaSpecification
1920
from serializer.tosa_serializer import TosaOp
2021
from torch.fx import Node
2122

2223

2324
@register_node_visitor
24-
class AddVisitor(NodeVisitor):
25+
class SumVisitor_080_BI(NodeVisitor):
2526
target = "aten.sum.dim_IntList"
2627

28+
tosa_specs = [
29+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
30+
]
31+
2732
def __init__(self, *args):
2833
super().__init__(*args)
2934

@@ -35,64 +40,89 @@ def define_node(
3540
output: TosaArg,
3641
is_quant_node: bool,
3742
) -> None:
38-
input_node = inputs[0]
39-
input_shape = list(input_node.shape)
43+
input_shape = list(inputs[0].shape)
4044
dim_list = cast(list[int], inputs[1].special)
41-
dim_list = [dim % len(input_node.shape) for dim in dim_list]
45+
dim_list = [dim % len(input_shape) for dim in dim_list]
4246
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
4347
assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
4448

45-
if is_quant_node:
49+
# Rescale input to 32 bit
50+
rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
51+
tosa_graph,
52+
[inputs[0]],
53+
node,
54+
)
55+
56+
prev_node = rescaled_inputs[0]
57+
reduced_shape = input_shape
58+
59+
# Reduce all dims in dim_list one-by-one.
60+
for dim in dim_list:
61+
# When reduced, the size of the dim becomes 1.
62+
reduced_shape[dim] = 1
63+
64+
attr = ts.TosaSerializerAttribute()
65+
attr.AxisAttribute(inputs[0].dim_order.index(dim))
66+
67+
next_node = tosa_graph.addIntermediate(
68+
tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
69+
dtype=ts.DType.INT32,
70+
)
71+
72+
tosa_graph.addOperator(
73+
TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
74+
)
75+
76+
prev_node = next_node
77+
tqutils.insert_rescale_op_to_int8(tosa_graph, prev_node, scale, node)
78+
79+
80+
@register_node_visitor
81+
class SumVisitor_080_MI(SumVisitor_080_BI):
82+
# inheriting 'target' from BI class
83+
84+
tosa_specs = [
85+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
86+
]
87+
88+
def __init__(self, *args):
89+
super().__init__(*args)
90+
91+
def define_node(
92+
self,
93+
node: Node,
94+
tosa_graph: ts.TosaSerializer,
95+
inputs: List[TosaArg],
96+
output: TosaArg,
97+
is_quant_node: bool,
98+
) -> None:
99+
if inputs[0].dtype == ts.DType.INT8:
100+
return super().define_node(node, tosa_graph, inputs, output, is_quant_node)
101+
input_name = inputs[0].name
102+
reduced_shape = list(inputs[0].shape)
103+
dim_list = cast(list[int], inputs[1].special)
104+
dim_list = [dim % len(reduced_shape) for dim in dim_list]
105+
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
106+
assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
107+
108+
# Reduce all dims in dim_list one-by-one.
109+
for dim in dim_list:
110+
# When reduced, the size of the dim becomes 1
111+
reduced_shape[dim] = 1
112+
113+
attr = ts.TosaSerializerAttribute()
114+
attr.AxisAttribute(inputs[0].dim_order.index(dim))
115+
116+
if dim == dim_list[-1]:
117+
output_name = output.name
118+
else:
119+
output_name = tosa_graph.addIntermediate(
120+
tutils.tosa_shape(reduced_shape, inputs[0].dim_order),
121+
dtype=ts.DType.FP32,
122+
).name
46123

47-
# Rescale input to 32 bit
48-
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
49-
[node.all_input_nodes[0]], tosa_graph
124+
tosa_graph.addOperator(
125+
TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
50126
)
51127

52-
prev_node = rescaled_inputs[0]
53-
reduced_shape = input_shape
54-
55-
# Reduce all dims in dim_list one-by-one.
56-
for dim in dim_list:
57-
# When reduced, the size of the dim becomes 1.
58-
reduced_shape[dim] = 1
59-
60-
attr = ts.TosaSerializerAttribute()
61-
attr.AxisAttribute(input_node.dim_order.index(dim))
62-
63-
next_node = tosa_graph.addIntermediate(
64-
tutils.tosa_shape(reduced_shape, input_node.dim_order),
65-
dtype=ts.DType.INT32,
66-
)
67-
68-
tosa_graph.addOperator(
69-
TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
70-
)
71-
72-
prev_node = next_node
73-
tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
74-
else:
75-
input_name = input_node.name
76-
reduced_shape = input_shape
77-
78-
# Reduce all dims in dim_list one-by-one.
79-
for dim in dim_list:
80-
# When reduced, the size of the dim becomes 1
81-
reduced_shape[dim] = 1
82-
83-
attr = ts.TosaSerializerAttribute()
84-
attr.AxisAttribute(input_node.dim_order.index(dim))
85-
86-
if dim == dim_list[-1]:
87-
output_name = output.name
88-
else:
89-
output_name = tosa_graph.addIntermediate(
90-
tutils.tosa_shape(reduced_shape, input_node.dim_order),
91-
dtype=ts.DType.FP32,
92-
).name
93-
94-
tosa_graph.addOperator(
95-
TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
96-
)
97-
98-
input_name = output_name
128+
input_name = output_name

0 commit comments

Comments
 (0)