Commit 1f04bad

Add support for concat q/dq folding
This is a special case where node.args can be lists with many incoming dq-nodes.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Icf511a8bdeaaffb597b18455ab7f1fbd947ce3ca
1 parent 7373bd6 commit 1f04bad

2 files changed: +47 −30 lines


backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 2 deletions
@@ -90,14 +90,15 @@ def transform_to_backend_pipeline(
         self.add_pass(
             FoldAndAnnotateQParamsPass(
                 [
-                    exir_ops.edge.aten.minimum.default,
-                    exir_ops.edge.aten.maximum.default,
                     exir_ops.edge.aten.add.Tensor,
                     exir_ops.edge.aten.avg_pool2d.default,
+                    exir_ops.edge.aten.cat.default,
                     exir_ops.edge.aten.convolution.default,
                     exir_ops.edge.aten.exp.default,
                     exir_ops.edge.aten.full.default,
                     exir_ops.edge.aten.log.default,
+                    exir_ops.edge.aten.maximum.default,
+                    exir_ops.edge.aten.minimum.default,
                     exir_ops.edge.aten.reciprocal.default,
                     exir_ops.edge.aten.rsqrt.default,
                     exir_ops.edge.aten.sigmoid.default,
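
For context on why aten.cat.default is the special case the commit message describes: unlike the other targeted ops, the concatenation operator takes its tensor inputs as a single Python list, so in a quantized edge graph args[0] of the cat node is a list whose elements are each produced by their own dequantize node, i.e. args = ([dq(x_q), dq(y_q)], dim). A minimal illustrative eager-mode module that produces this pattern (the module name and dimensions are assumptions, not part of the commit):

import torch

class ConcatModule(torch.nn.Module):
    # After quantization and export to the edge dialect, torch.cat lowers to
    # cat.default with args = ([dq_node_0, dq_node_1], 1) - a list of
    # incoming dq-nodes rather than individual Node arguments.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.cat([x, y], dim=1)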

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 44 additions & 28 deletions
@@ -79,6 +79,46 @@ def __init__(self, targeted_ops: Iterable[Callable]):
         super().__init__()
         self.targeted_ops = targeted_ops
 
+    def fold_and_annotate_arg(
+        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
+    ):
+        input_qparams = None
+        nodes_to_remove = set()
+        for arg in arg_list:
+            if not isinstance(arg, Node):
+                return
+            """
+            Make sure arg has requires_grad set to False
+            For parameters that are not quantized, sometimes (i.e. convolution)
+            the Parameter(FakeTensor(...)) has requires_grad set to True, which
+            causes the retracing of the graph to fail with:
+
+            E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+            E
+            E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+            E       Original traceback:
+            E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+            E           x = conv(x)
+            """
+            if arg.op == "placeholder":
+                arg.meta["val"].requires_grad = False
+
+            arg_quant_params = None
+            if arg.target == dq_op:
+                arg_quant_params = QuantArgs.from_operator(arg.target, arg.args)
+                # add arg to nodes_to_remove to fold the dq-node
+                nodes_to_remove.add(arg)
+            if input_qparams is not None and input_qparams != arg_quant_params:
+                # Two args are quantized differently
+                raise RuntimeError("Input qparams does not match!")
+            input_qparams = arg_quant_params
+        if input_qparams is not None:
+            node.meta["input_qparams"][i] = input_qparams
+        for n in nodes_to_remove:
+            assert n.target == dq_op
+            n.replace_all_uses_with(n.args[0])
+            graph_module.graph.erase_node(n)
+
     def call(self, graph_module: GraphModule) -> PassResult:
 
         # Loop over the graph nodes and find any node in the 'targeted_ops' list.
@@ -97,35 +137,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n.meta["input_qparams"] = {}
             n.meta["output_qparams"] = {}
             for i, arg in enumerate(n.args):
-                if not isinstance(arg, Node):
-                    continue
-
-                # Make sure arg has requires_grad set to False
-                # For parameters that are not quantized, sometimes (i.e. convolution)
-                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
-                # causes the retracing of the graph to fail with:
-                #
-                # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
-                # E
-                # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
-                # E       Original traceback:
-                # E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
-                # E           x = conv(x)
-                #
-                if arg.op == "placeholder":
-                    arg.meta["val"].requires_grad = False
-
-                if arg.target != dq_op:
-                    continue
-
-                # arg.target for argument i is a dequant node, extract the information
-                n.meta["input_qparams"][i] = QuantArgs.from_operator(
-                    arg.target, arg.args
-                )
+                if isinstance(arg, list):
+                    self.fold_and_annotate_arg(graph_module, n, arg, i)
 
-                # arg.args[0] is the tensor input, replace the input usage
-                n.replace_input_with(arg, arg.args[0])
-                graph_module.graph.erase_node(arg)
+                elif isinstance(arg, Node):
+                    self.fold_and_annotate_arg(graph_module, n, [arg], i)
 
             # Copy the users, since we are modifying it.
             users_copy = copy.copy(n.users)
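
The reworked loop in call routes both argument shapes through the same helper: a list argument (the cat case) is passed as-is, while a single Node argument is wrapped in a one-element list. A standalone sketch of that dispatch pattern, with illustrative names rather than the pass's actual signatures:

from torch.fx import Node

def dispatch_node_args(node: Node, fold_one_arg) -> None:
    # Route every argument through one code path, whether it is a single
    # graph node or a list of nodes (as produced by cat.default).
    for i, arg in enumerate(node.args):
        if isinstance(arg, list):
            fold_one_arg(node, arg, i)
        elif isinstance(arg, Node):
            fold_one_arg(node, [arg], i)

Treating the single-node case as a one-element list keeps one folding path, and lets the helper verify that all dq-nodes feeding one argument position (for example every tensor concatenated by cat) carry identical quantization parameters before those dq-nodes are folded away.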
