Commit acfd652

Revert "Revert "Add full operator to fold dq/q handling" (#7351)"

This reverts commit 11beed1.

1 parent: eca5d9f

24 files changed: +963 -145 lines

backends/arm/_passes/arm_pass_manager.py

Lines changed: 18 additions & 0 deletions
@@ -29,6 +29,10 @@
     DecomposeSoftmaxesPass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    FoldAndAnnotateQParamsPass,
+    QuantizeFullArgument,
+)
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )
@@ -50,6 +54,7 @@
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager


@@ -80,6 +85,19 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(QuantizeFullArgument())
+        self.add_pass(
+            FoldAndAnnotateQParamsPass(
+                [
+                    exir_ops.edge.aten.minimum.default,
+                    exir_ops.edge.aten.maximum.default,
+                    exir_ops.edge.aten.add.Tensor,
+                    exir_ops.edge.aten.avg_pool2d.default,
+                    exir_ops.edge.aten.convolution.default,
+                    exir_ops.edge.aten.full.default,
+                ]
+            )
+        )
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
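Note the registration order: QuantizeFullArgument runs before FoldAndAnnotateQParamsPass, since the fill value of full.default must already be quantized by the time the Q node after it is folded away (see the pass docstrings in the new file below). A minimal standalone sketch of the arithmetic involved — the scale/zero-point values and the quantize_value helper here are hypothetical stand-ins, not the ExecuTorch API:

# Hypothetical standalone sketch: pre-quantizing a full.default fill value.
scale, zero_point, qmin, qmax = 0.05487706884741783, -128, -128, 127

def quantize_value(x: float) -> int:
    # Same round-and-clamp scheme as quantize_per_tensor.
    return min(max(round(x / scale) + zero_point, qmin), qmax)

fill_value = 0.5                     # the original fp32 argument to full.default
print(quantize_value(fill_value))    # -119; full.default can now emit int8 directly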
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py (new file)

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+from typing import Callable, cast, Iterable
+
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
+
+def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "input_qparams" not in node.meta.keys():
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
+    if len(input_qparams) == 0:
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    return input_qparams
+
+
+def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "output_qparams" not in node.meta.keys():
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
+    if len(output_qparams) == 0:
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    return output_qparams
+
+
+class FoldAndAnnotateQParamsPass(ExportPass):
+    """
+    A pass that walks the graph and removes any DQ and Q nodes before and after the target
+    node in the supplied list of operators.
+    The quantization parameters from the DQ/Q nodes are stored as meta values to be
+    accessible for later lowering and serialization passes.
+    The assumption is that the quantization annotation adds DQ nodes for all tensor
+    inputs to the target and one Q node to the output.
+
+    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
+
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
+        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    Becomes:
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
+    """
+
+    def __init__(self, targeted_ops: Iterable[Callable]):
+        super().__init__()
+        self.targeted_ops = targeted_ops
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+
+        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            # Make sure we haven't already set qparams meta information on the node.
+            assert "input_qparams" not in n.meta.keys()
+            assert "output_qparams" not in n.meta.keys()
+
+            # For the inputs and outputs, search the graph for quantization info and
+            # store the information in a dict keyed by the order of the _tensor_ inputs,
+            # ignoring any other arguments to the target node.
+            n.meta["input_qparams"] = {}
+            n.meta["output_qparams"] = {}
+            for i, arg in enumerate(n.args):
+                if not isinstance(arg, Node):
+                    continue
+
+                # Make sure arg has requires_grad set to False.
+                # For parameters that are not quantized, sometimes (e.g. convolution)
+                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
+                # causes the retracing of the graph to fail with:
+                #
+                # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+                # E
+                # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+                # E       Original traceback:
+                # E       File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+                # E           x = conv(x)
+                #
+                if arg.op == "placeholder":
+                    arg.meta["val"].requires_grad = False
+
+                if arg.target != dq_op:
+                    continue
+
+                # arg.target for argument i is a dequant node; extract the information.
+                n.meta["input_qparams"][i] = QuantArgs.from_operator(
+                    arg.target, arg.args
+                )
+
+                # arg.args[0] is the tensor input; replace the input usage.
+                n.replace_input_with(arg, arg.args[0])
+                graph_module.graph.erase_node(arg)
+
+            # Copy the users, since we are modifying the collection as we iterate.
+            users_copy = copy.copy(n.users)
+            for i, user in enumerate(users_copy):
+                if user.target != q_op:
+                    continue
+
+                # Quantization node found here; store the quantization parameters in the meta value.
+                n.meta["output_qparams"][i] = QuantArgs.from_operator(
+                    user.target, user.args
+                )
+
+                user.replace_all_uses_with(n)
+                graph_module.graph.erase_node(user)
+
+        # Retrace the graph to update the fake tensor types.
+        graph_module = super().call(graph_module).graph_module
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+class QuantizeFullArgument(ExportPass):
+    """
+    Make sure the fill_value for full.default is quantized. This pass needs to be run before
+    the folding pass above to make sure that the retraced output of the full.default op is
+    the right dtype.
+    """
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        # Loop over the graph nodes and find any full.default node.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target != exir_ops.edge.aten.full.default:
+                continue
+
+            # Make sure we have a quantized operator
+            user = list(n.users)[0]
+            if user.target != q_op:
+                continue
+
+            qargs = QuantArgs.from_operator(user.target, user.args)
+            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
+                # Replace the node arg with its quantized counterpart and also set dtype
+                # to get the right output according to the Edge IR specification:
+                # exir/dialects/edge/edge.yaml:3596
+                quantized_full_value = qargs.quantize_value(n.args[1]).item()
+                n.update_arg(1, quantized_full_value)
+                n.update_kwarg("dtype", qargs.dtype)
+                modified = True
+
+        return PassResult(graph_module, modified)
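After the fold, downstream lowering steps no longer see explicit Q/DQ nodes; they read the annotated parameters instead via the two helpers added above. A sketch of what a consumer might look like — the lower_quantized_op function and its body are illustrative only, not part of this commit:

# Illustrative-only sketch of a lowering step consuming the folded annotations.
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)
from torch.fx import Node

def lower_quantized_op(node: Node) -> None:
    # Raises ValueError if FoldAndAnnotateQParamsPass never annotated this node.
    input_qparams = get_input_qparams(node)    # dict: tensor-arg index -> QuantArgs
    output_qparams = get_output_qparams(node)  # dict: user index -> QuantArgs
    for idx, qargs in input_qparams.items():
        # Each QuantArgs carries the parameters recovered from the folded DQ node
        # for input idx; a real visitor would emit RESCALE ops from these.
        print(idx, qargs)
    print(output_qparams)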

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.minimum.default,
+            exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.repeat.default,
             exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,7 +19,9 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max,
     op_max_pool2d,
+    op_min,
     op_mm,
     op_mul,
     op_permute,

backends/arm/operators/op_add.py

Lines changed: 26 additions & 25 deletions
@@ -11,7 +11,6 @@
 import executorch.backends.arm.tosa_utils as tutils

 import serializer.tosa_serializer as ts
-import torch
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -41,33 +40,27 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        input_nodes = tutils.get_two_inputs(node)
-
-        if not is_quant_node and not all(
-            tensor.meta["val"].dtype in (torch.int8, torch.int32)
-            for tensor in input_nodes
-        ):
-            raise RuntimeError(
-                f"Unexpected non quantized {AddVisitor_080_BI.target} node."
-            )
-
-        needs_rescale = not (
-            all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
-            and node.meta["val"].dtype == torch.int32
-        )
-
-        if needs_rescale:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
-                input_nodes, tosa_graph
+        # The specification (0.80) states that input and output types
+        # should all be the same.
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+        # Handle int8 (quantized) and int32
+        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
+
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+                tosa_graph, inputs, node
             )
+        else:
+            # inputs[0].dtype == ts.DType.INT32
+            # Non-quantized input, natively supported by TOSA.ADD
+            rescaled_inputs = inputs

-            # Prepare add output tensor
+        if output.dtype == ts.DType.INT8:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
+            # output.dtype == ts.DType.INT32
             add_output = output
-            rescaled_inputs = inputs

         # Do the INT32 Add
         tosa_graph.addOperator(
@@ -80,10 +73,10 @@ def define_node(
             None,
         )

-        if needs_rescale:
+        if output.dtype == ts.DType.INT8:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
+            tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)


 @register_node_visitor
@@ -105,11 +98,19 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        if is_quant_node:
+        # The specification (0.80) states that input and output types
+        # should all be the same.
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output, is_quant_node)
         else:
             # FP32 Add lowering
+            assert inputs[0].dtype == ts.DType.FP32
+            assert output.dtype == ts.DType.FP32
+
+            # MI lowering
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [inputs[0].name, inputs[1].name],