Commit d7e560a

author: pytorchbot
committed: 2024-12-16 nightly release (99d5b80)
1 parent: 8f1099e

34 files changed: +2575 -42 lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 13 additions & 0 deletions

@@ -29,6 +29,9 @@
     DecomposeSoftmaxesPass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    FoldAndAnnotateQParamsPass,
+)
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )
@@ -50,6 +53,7 @@
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager
@@ -80,6 +84,15 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(
+            FoldAndAnnotateQParamsPass(
+                [
+                    exir_ops.edge.aten.minimum.default,
+                    exir_ops.edge.aten.maximum.default,
+                    exir_ops.edge.aten.add.Tensor,
+                ]
+            )
+        )
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
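
For orientation, a minimal, hedged sketch of how this pipeline is typically driven so that the new FoldAndAnnotateQParamsPass actually runs. The entry point and argument names are inferred from the context lines above, not confirmed by this commit; `exported_program` and `compile_spec` are assumed to come from the usual export flow:

    from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

    # Assumed inputs: `exported_program` is an exir ExportedProgram and
    # `compile_spec` is a list of CompileSpec entries for the Arm backend.
    pass_manager = ArmPassManager()
    graph_module = pass_manager.transform_to_backend_pipeline(
        exported_program=exported_program,
        compile_spec=compile_spec,
    )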

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 2 deletions

@@ -36,9 +36,11 @@ def call_operator(self, op, args, kwargs, meta):
         ]
 
         # To convert expand arg to repeat arg, non-repeated dims should have
-        # multiples[dim] = 1.
+        # multiples[dim] = 1. Passing -1 as an expand arg means
+        # not changing the size of that dimension.
         multiples = [
-            multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank)
+            multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
+            for i in range(expanded_rank)
         ]
         return super().call_operator(
             op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
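
To make the new -1 guard concrete, here is a standalone worked example of the corrected comprehension, using toy values and none of the surrounding pass machinery:

    # x.expand(4, -1, -1) on a tensor of shape [1, 1, 3]; -1 keeps that dim's size.
    extended_shape = [1, 1, 3]   # input shape, padded to the expanded rank
    multiples = [4, -1, -1]      # the expand() size arguments
    expanded_rank = 3

    multiples = [
        multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
        for i in range(expanded_rank)
    ]
    print(multiples)  # [4, 1, 1]

    # The old comprehension kept multiples[1] = -1 (extended_shape[1] == 1 and
    # no -1 guard), handing an invalid repeat count to the TOSA repeat lowering.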

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+from typing import Callable, cast, Iterable
+
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+
+def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "input_qparams" not in node.meta.keys():
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
+    if len(input_qparams) == 0:
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    return input_qparams
+
+
+def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "output_qparams" not in node.meta.keys():
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
+    if len(output_qparams) == 0:
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    return output_qparams
+
+
+class FoldAndAnnotateQParamsPass(ExportPass):
+    """
+    A pass that walks the graph and removes any DQ and Q nodes before and after the target
+    node in the supplied list of operators.
+    The quantization parameters from the DQ/Q nodes are stored as meta values to be
+    accessible for later lowering and serialization passes.
+    The assumption is that the quantization annotation adds DQ nodes for all tensor
+    inputs to the target node and one Q node to the output.
+
+    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
+
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
+        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    Becomes:
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
+    """
+
+    def __init__(self, targeted_ops: Iterable[Callable]):
+        super().__init__()
+        self.targeted_ops = targeted_ops
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
+        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            # Make sure we haven't already set qparams meta information on the node.
+            assert "input_qparams" not in n.meta.keys()
+            assert "output_qparams" not in n.meta.keys()
+
+            # For the inputs and outputs, search the graph for quantization info and
+            # store the information in a dict keyed by the order of the _tensor_ inputs,
+            # ignoring any other arguments to the target node.
+            n.meta["input_qparams"] = {}
+            n.meta["output_qparams"] = {}
+            for i, arg in enumerate(n.args):
+                if not isinstance(arg, Node):
+                    continue
+                if arg.target != dq_op:
+                    continue
+
+                # arg is a dequant node for argument i; extract its parameters.
+                n.meta["input_qparams"][i] = QuantArgs.from_operator(
+                    arg.target, arg.args
+                )
+
+                # arg.args[0] is the tensor input; replace the input usage.
+                n.replace_input_with(arg, arg.args[0])
+                graph_module.graph.erase_node(arg)
+
+            # Copy the users, since we are modifying them while iterating.
+            users_copy = copy.copy(n.users)
+            for i, user in enumerate(users_copy):
+                if user.target != q_op:
+                    continue
+
+                # Quantize node found; store its quantization parameters in meta.
+                n.meta["output_qparams"][i] = QuantArgs.from_operator(
+                    user.target, user.args
+                )
+
+                user.replace_all_uses_with(n)
+                graph_module.graph.erase_node(user)
+
+        # Retrace the graph to update the fake tensor types.
+        graph_module = super().call(graph_module).graph_module
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
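
A hedged usage sketch of the new pass and its accessors: `ExportPass` instances are callable and return a `PassResult`, and `edge_graph_module` is assumed to be an edge-dialect GraphModule that still carries Q/DQ nodes around an add:

    from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
        FoldAndAnnotateQParamsPass,
        get_input_qparams,
        get_output_qparams,
    )
    from executorch.exir.dialects._ops import ops as exir_ops

    fold_pass = FoldAndAnnotateQParamsPass([exir_ops.edge.aten.add.Tensor])
    edge_graph_module = fold_pass(edge_graph_module).graph_module

    for n in edge_graph_module.graph.nodes:
        if n.target == exir_ops.edge.aten.add.Tensor:
            # Keys are the positional indices of the node's tensor inputs/outputs.
            print(get_input_qparams(n))   # e.g. {0: QuantArgs(...), 1: QuantArgs(...)}
            print(get_output_qparams(n))  # e.g. {0: QuantArgs(...)}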

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions

@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.minimum.default,
+            exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -19,7 +19,9 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max,
     op_max_pool2d,
+    op_min,
     op_mm,
     op_mul,
     op_permute,

backends/arm/operators/op_add.py

Lines changed: 26 additions & 25 deletions

@@ -11,7 +11,6 @@
 import executorch.backends.arm.tosa_utils as tutils
 
 import serializer.tosa_serializer as ts
-import torch
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -41,33 +40,27 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        input_nodes = tutils.get_two_inputs(node)
-
-        if not is_quant_node and not all(
-            tensor.meta["val"].dtype in (torch.int8, torch.int32)
-            for tensor in input_nodes
-        ):
-            raise RuntimeError(
-                f"Unexpected non quantized {AddVisitor_080_BI.target} node."
-            )
-
-        needs_rescale = not (
-            all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
-            and node.meta["val"].dtype == torch.int32
-        )
-
-        if needs_rescale:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
-                input_nodes, tosa_graph
+        # Specification (0.80.0) states that input and output types
+        # should all be the same.
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+        # Handle int8 (quantized) and int32.
+        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
+
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+                tosa_graph, inputs, node
             )
+        else:
+            # inputs[0].dtype == ts.DType.INT32
+            # Non-quantized input, natively supported by TOSA.ADD.
+            rescaled_inputs = inputs
 
-        # Prepare add output tensor
+        if output.dtype == ts.DType.INT8:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
+            # output.dtype == ts.DType.INT32
             add_output = output
-            rescaled_inputs = inputs
 
         # Do the INT32 Add
         tosa_graph.addOperator(
@@ -80,10 +73,10 @@ def define_node(
             None,
         )
 
-        if needs_rescale:
+        if output.dtype == ts.DType.INT8:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
+            tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
 
 
 @register_node_visitor
@@ -105,11 +98,19 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        if is_quant_node:
+        # Specification (0.80.0) states that input and output types
+        # should all be the same.
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output, is_quant_node)
         else:
             # FP32 Add lowering
+            assert inputs[0].dtype == ts.DType.FP32
+            assert output.dtype == ts.DType.FP32
+
+            # MI lowering
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [inputs[0].name, inputs[1].name],
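
The rewrite keys everything off TOSA dtypes and a rescale-to-int32 pattern. Below is a pure-Python stand-in for the arithmetic the inserted TOSA RESCALE operators perform; the real scale bookkeeping lives in tqutils.insert_rescale_ops_to_int32 / insert_rescale_op_to_int8, and the qparams here are toy assumptions:

    # Toy numeric model of an int8 add through an int32 accumulator.
    scale_a = scale_b = 0.05   # assumed input quantization scales (zero points 0)
    scale_out = 0.1            # assumed output quantization scale
    qa, qb = 40, 30            # int8 operands, i.e. real values 2.0 and 1.5

    # 1. Rescale both operands to a shared scale in int32, where the int8
    #    range can safely overflow during the add.
    shared_scale = max(scale_a, scale_b)
    a32 = round(qa * scale_a / shared_scale)       # 40
    b32 = round(qb * scale_b / shared_scale)       # 30

    # 2. Add in int32.
    acc = a32 + b32                                # 70, i.e. real value 3.5

    # 3. Rescale the accumulator back to int8 at the output scale.
    q_out = round(acc * shared_scale / scale_out)  # 35 -> dequantizes to 3.5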

backends/arm/operators/op_max.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import List
+
+import executorch.backends.arm.tosa_quant_utils as tqutils
+import serializer.tosa_serializer as ts
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+)
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_utils import tosa_shape
+
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class MaxVisitor(NodeVisitor):
+    target = "aten.maximum.default"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        assert inputs[0].dtype == inputs[1].dtype
+
+        max_output = output
+        if inputs[0].dtype == ts.DType.INT8:
+            input_qparams = get_input_qparams(node)
+            assert (
+                len(input_qparams) == 2
+            ), f"Both inputs need to have quantization information for {node}"
+            # Insert RESCALEs to int32.
+            assert (
+                input_qparams[0] == input_qparams[1]
+            ), "Both inputs must have the same quantization for MAX"
+
+            operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+                tosa_graph, inputs, node
+            )
+
+            output.shape = tosa_shape(output.shape, output.dim_order)
+            max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
+        else:
+            operand_inputs = inputs
+
+        tosa_graph.addOperator(
+            TosaOp.Op().MAXIMUM,
+            [
+                operand_inputs[0].name,
+                operand_inputs[1].name,
+            ],
+            [max_output.name],
+        )
+
+        if output.dtype == ts.DType.INT8:
+            # Insert RESCALE from int32 back to int8.
+            tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node)
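
The input_qparams[0] == input_qparams[1] assertion is load-bearing: the max of raw int8 values matches the max of the real values only when both operands share scale and zero point. A small self-contained check with toy numbers, not from this commit:

    def dequant(q: int, scale: float, zp: int) -> float:
        return (q - zp) * scale

    # Shared qparams: comparing raw int8 values gives the right answer.
    s, zp = 0.05, 0
    qa, qb = 40, 30
    assert max(dequant(qa, s, zp), dequant(qb, s, zp)) == dequant(max(qa, qb), s, zp)

    # Mismatched scales: the raw-integer comparison picks the wrong operand.
    # Real values are 40 * 0.01 = 0.4 and 30 * 0.05 = 1.5, so the true max is
    # 1.5, yet max(40, 30) selects the 0.4-valued operand.
    assert max(dequant(40, 0.01, 0), dequant(30, 0.05, 0)) == 1.5
    assert max(40, 30) == 40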
