
Commit 4c80294

Add support for concat q/dq folding
This is a special case where node.args can be lists with many incoming dq-nodes.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Icf511a8bdeaaffb597b18455ab7f1fbd947ce3ca
1 parent 45c1976 · commit 4c80294
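To illustrate the special case the commit message describes (a sketch added for this write-up, not part of the commit): when a module calls torch.cat, the traced node's first argument is a Python list of producer Nodes, whereas most other targeted ops receive each input as an individual Node. The toy CatModule below is hypothetical:

```python
import torch
import torch.fx

class CatModule(torch.nn.Module):  # toy module, written for this note
    def forward(self, x, y):
        return torch.cat([x, y], dim=0)

gm = torch.fx.symbolic_trace(CatModule())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # For cat, node.args[0] is a *list* of Nodes ([x, y]); most other
        # targeted ops pass each input as an individual Node argument.
        print(node.target, node.args, node.kwargs)
```

Each element of that list can be produced by its own dq-node, which is why the pass below folds and compares qparams across the whole list rather than per argument.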

File tree

2 files changed: +47, -31 lines

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -90,14 +90,15 @@ def transform_to_backend_pipeline(
         self.add_pass(
             FoldAndAnnotateQParamsPass(
                 [
-                    exir_ops.edge.aten.minimum.default,
-                    exir_ops.edge.aten.maximum.default,
                     exir_ops.edge.aten.add.Tensor,
                     exir_ops.edge.aten.avg_pool2d.default,
+                    exir_ops.edge.aten.cat.default,
                     exir_ops.edge.aten.convolution.default,
                     exir_ops.edge.aten.exp.default,
                     exir_ops.edge.aten.full.default,
                     exir_ops.edge.aten.log.default,
+                    exir_ops.edge.aten.maximum.default,
+                    exir_ops.edge.aten.minimum.default,
                     exir_ops.edge.aten.reciprocal.default,
                     exir_ops.edge.aten.rsqrt.default,
                     exir_ops.edge.aten.sigmoid.default,
```

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 44 additions & 29 deletions
```diff
@@ -80,6 +80,46 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
         super().__init__()
         self.targeted_ops = targeted_ops
 
+    def fold_and_annotate_arg(
+        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
+    ):
+        input_qparams = None
+        nodes_to_remove = set()
+        for arg in arg_list:
+            if not isinstance(arg, Node):
+                return
+            """
+            Make sure arg has requires_grad set to False
+            For parameters that are not quantized, sometimes (i.e. convolution)
+            the Parameter(FakeTensor(...)) has requires_grad set to True, which
+            causes the retracing of the graph to fail with:
+
+            E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+            E
+            E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+            E       Original traceback:
+            E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+            E           x = conv(x)
+            """
+            if arg.op == "placeholder":
+                arg.meta["val"].requires_grad = False
+
+            arg_quant_params = None
+            if arg.target == dq_op:
+                arg_quant_params = QuantArgs.from_operator(arg.target, arg.args)
+                # add arg to nodes_to_remove to fold the dq-node
+                nodes_to_remove.add(arg)
+            if input_qparams is not None and input_qparams != arg_quant_params:
+                # Two args are quantized differently
+                raise RuntimeError("Input qparams does not match!")
+            input_qparams = arg_quant_params
+        if input_qparams is not None:
+            node.meta["input_qparams"][i] = input_qparams
+        for n in nodes_to_remove:
+            assert n.target == dq_op
+            n.replace_all_uses_with(n.args[0])
+            graph_module.graph.erase_node(n)
+
     def call(self, graph_module: GraphModule) -> PassResult:
 
         # Loop over the graph nodes and find any node in the 'targeted_ops' list.
@@ -98,36 +138,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n.meta["input_qparams"] = {}
             n.meta["output_qparams"] = {}
             for i, arg in enumerate(n.args):
-                if not isinstance(arg, Node):
-                    continue
-
-                # Make sure arg has requires_grad set to False
-                # For parameters that are not quantized, sometimes (i.e. convolution)
-                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
-                # causes the retracing of the graph to fail with:
-                #
-                # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
-                # E
-                # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
-                # E       Original traceback:
-                # E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
-                # E           x = conv(x)
-                #
-                if arg.op == "placeholder":
-                    arg.meta["val"].requires_grad = False
-
-                if arg.target != dq_op:
-                    continue
-
-                # arg.target for argument i is a dequant node, extract the information
-                n.meta["input_qparams"][i] = QuantArgs.from_operator(
-                    arg.target, arg.args
-                )
+                if isinstance(arg, list):
+                    self.fold_and_annotate_arg(graph_module, n, arg, i)
 
-                # arg.args[0] is the tensor input, replace the input usage
-                tensor_input = cast(Node, arg.args[0])
-                n.replace_input_with(arg, tensor_input)
-                graph_module.graph.erase_node(arg)
+                elif isinstance(arg, Node):
+                    self.fold_and_annotate_arg(graph_module, n, [arg], i)
 
             # Copy the users, since we are modifying it.
             users_copy = copy.copy(n.users)
```
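For readers unfamiliar with torch.fx graph surgery, here is a minimal, self-contained sketch of the fold step the pass performs: rewire a node's users to its tensor input with replace_all_uses_with, then erase it. This is not ExecuTorch code; fake_dq is a hypothetical passthrough standing in for the real dequantize op:

```python
import torch
import torch.fx

@torch.fx.wrap  # keep fake_dq as a graph node instead of tracing through it
def fake_dq(x):
    # Hypothetical passthrough standing in for a dequantize op.
    return x

class Model(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([fake_dq(x), fake_dq(y)], dim=0)

gm = torch.fx.symbolic_trace(Model())
# Fold every fake_dq node: rewire its users to its tensor input, then erase
# it, mirroring the pass's replace_all_uses_with / erase_node sequence.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is fake_dq:
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.recompile()
print(gm.graph)  # cat now consumes x and y directly
print(gm(torch.ones(2), torch.zeros(2)))
```

The real pass additionally records the dq-nodes' quantization parameters in node.meta["input_qparams"][i] before erasing them, and raises if two inputs feeding the same list argument carry different qparams.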
