
Commit d292254

Qualcomm AI Engine Direct - fix for pytorch uplevel
Summary:
- Use 'fold_quantize=False' in convert_pt2e to prevent overwriting the state_dict during lowering.
- Change _get_updated_graph_signature so the graph signature is detected correctly.
1 parent 304c567 commit d292254
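For context, a minimal sketch of the PT2E flow this change touches. `model`, `example_inputs`, and `quantizer` are placeholders, not part of this commit, and the capture entry point varies across PyTorch versions:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# capture and prepare (quantizer is assumed to be a configured QnnQuantizer)
graph = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(graph, quantizer)
prepared(*example_inputs)  # calibration

# fold_quantize=False keeps quantize ops in the graph instead of folding
# quantized weights into the state_dict, so the original fp32 parameters
# are not overwritten before lowering
converted = convert_pt2e(prepared, fold_quantize=False)
```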

File tree

13 files changed: +179 -136 lines changed

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 4 additions & 44 deletions
```diff
@@ -7,22 +7,17 @@
 from typing import Any, Dict
 
 import torch
-from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
+from executorch.backends.qualcomm.builders.utils import get_parameter
 from executorch.backends.qualcomm.utils.constants import (
-    QCOM_AXIS,
-    QCOM_BLOCK_SIZE,
     QCOM_DTYPE,
     QCOM_ENCODING,
     QCOM_QUANT_ATTRS,
     QCOM_QUANT_MAX,
     QCOM_QUANT_MIN,
     QCOM_REQUANTIZE,
     QCOM_SCALE,
-    QCOM_SCALES,
     QCOM_ZERO_POINT,
-    QCOM_ZERO_POINTS,
 )
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
 from .utils import dq_ops, get_quant_attrs, q_ops
@@ -86,6 +81,9 @@ def _annotate_requant(self, n):
             q_attrs = get_quant_attrs(self.edge_program, n)
             for dq_node in dq_nodes:
                 dq_attrs = get_quant_attrs(self.edge_program, dq_node)
+                # bypass parameters
+                if n.args[0].op == "placeholder":
+                    continue
                 # TODO: Store multiple pairs of requantize attributes when we have an op builder
                 # that has multiple outputs that requires quant attributes.
                 if self.skip_advanced_requant:
@@ -113,43 +111,9 @@ def _annotate_requant(self, n):
                     n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
                     n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
 
-    # Dequant all the fold_quant parameters back to fp32.
-    # If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
-    def _dequant_fold_params(self, n, quant_attrs, param):
-        if quant_attrs[QCOM_ENCODING] in [
-            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
-        ]:
-            dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
-            scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
-            offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
-            param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
-        elif quant_attrs[QCOM_ENCODING] in [
-            exir_ops.edge.pt2e_quant.dequantize_affine.default
-        ]:
-            param = torch.ops.pt2e_quant.dequantize_affine(
-                param,
-                block_size=quant_attrs[QCOM_BLOCK_SIZE],
-                scale=quant_attrs[QCOM_SCALE],
-                zero_point=quant_attrs[QCOM_ZERO_POINT],
-                input_dtype=quant_attrs[QCOM_DTYPE],
-                quant_min=quant_attrs[QCOM_QUANT_MIN],
-                quant_max=quant_attrs[QCOM_QUANT_MAX],
-                output_dtype=torch.float32,
-            )
-        else:
-            scale = quant_attrs[QCOM_SCALE]
-            offset = quant_attrs[QCOM_ZERO_POINT]
-            param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
-
-        set_parameter(param, n.args[0], self.edge_program)
-        n.args[0].meta["val"] = param
-
     def _annotate_quant_attrs(
         self, graph_module: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        # Keep track of const params that has been dequant, so it does not get
-        # dequant multiple times if the const param has more than 1 user
-        visited_const_param = set()
         for n in graph_module.graph.nodes:
             self._annotate_requant(n)
             # With fold_quant enabled, check if the input of dq op is quantized param.
@@ -161,10 +125,6 @@ def _annotate_quant_attrs(
             quant_attrs = get_quant_attrs(self.edge_program, n)
             self._annotate_source_nodes(n, quant_attrs)
 
-            if param is not None and n.args[0] not in visited_const_param:
-                visited_const_param.add(n.args[0])
-                self._dequant_fold_params(n, quant_attrs, param)
-
         return graph_module
 
     def call(self, graph_module: torch.fx.GraphModule):
```
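The removed _dequant_fold_params applied standard affine dequantization to folded parameters; with fold_quantize=False the weights stay fp32, so it is no longer needed. For reference, a self-contained sketch of the per-tensor form it implemented (not code from this commit):

```python
import torch

def dequantize_per_tensor(param_q, scale, zero_point):
    # affine dequantization: x_fp32 = (x_q - zero_point) * scale
    return param_q.sub(zero_point).mul(scale).to(torch.float32).contiguous()

q = torch.tensor([10, 20, 30], dtype=torch.int8)
print(dequantize_per_tensor(q, scale=0.05, zero_point=10))
# tensor([0.0000, 0.5000, 1.0000])
```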

backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

Lines changed: 82 additions & 26 deletions
```diff
@@ -8,7 +8,6 @@
 import torch.nn as nn
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
 from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
 from .utils import copy_meta
@@ -23,16 +22,43 @@ class ConvertConv1dToConv2d(ExportPass):
     def __init__(self, edge_program: torch.export.ExportedProgram):
         super(ConvertConv1dToConv2d, self).__init__()
         self.edge_program = edge_program
+        self.conv_op_map = {
+            torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
+        }
+
+    def append_qdq(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+        qdq_node: torch.fx.Node,
+    ):
+        q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+        dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        if qdq_node.target not in {q_op, dq_op}:
+            return node
+
+        with graph_module.graph.inserting_after(node):
+            q_args = (node, *qdq_node.args[1:])
+            q_node = graph_module.graph.create_node("call_function", q_op, q_args)
+            q_node.meta = copy_meta(node.meta)
+            q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
+            with graph_module.graph.inserting_after(q_node):
+                dq_args = (q_node, *qdq_node.args[1:])
+                dq_node = graph_module.graph.create_node(
+                    "call_function", dq_op, dq_args
+                )
+                dq_node.meta = copy_meta(node.meta)
+
+        return dq_node
 
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        conv_op = exir_ops.edge.aten.convolution.default
         for node in graph.nodes:
-            if node.target == conv_op and node.meta["val"].dim() == 3:
-
+            if node.target in self.conv_op_map:
                 input_node = node.args[0]
                 with graph_module.graph.inserting_after(input_node):
-                    unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+                    unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
                     unsqueeze_node = graph.create_node(
                         "call_function",
                         unsqueeze_op,
@@ -44,10 +70,19 @@ def call(self, graph_module: torch.fx.GraphModule):
                     unsqueeze_node.meta = copy_meta(
                         input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                     )
+                    qdq_node_after_unsqueeze = self.append_qdq(
+                        graph_module=graph_module,
+                        node=unsqueeze_node,
+                        qdq_node=input_node,
+                    )
 
-                with graph_module.graph.inserting_after(unsqueeze_node):
-
-                    filter_node = node.args[1]
+                with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
+                    filter_arg = node.args[1]
+                    filter_node = (
+                        filter_arg
+                        if filter_arg.op == "placeholder"
+                        else node.args[1].args[0].args[0]
+                    )
                     filter_node.meta["val"] = (
                         filter_node.meta["val"].unsqueeze(2).contiguous()
                     )
@@ -56,40 +91,59 @@ def call(self, graph_module: torch.fx.GraphModule):
                     filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
                     set_parameter(filter_tensor, filter_node, self.edge_program)
 
+                    num_args = len(node.args)
                     bias_node = node.args[2]
-                    stride = [1] + node.args[3]
-                    padding = [0] + node.args[4]
-                    dilation = [1] + node.args[5]
-                    transpose = node.args[6]
-                    output_padding = [0] + node.args[7]
-                    groups = node.args[8]
-
-                    conv2d_node = graph.create_node(
-                        "call_function",
-                        conv_op,
-                        (
-                            unsqueeze_node,
-                            filter_node,
+                    stride = [1] + node.args[3] if num_args > 3 else [1, 1]
+                    padding = [0] + node.args[4] if num_args > 4 else [0, 0]
+                    if node.target == torch.ops.aten.conv1d.default:
+                        dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
+                        groups = node.args[6] if num_args > 5 else 1
+                        conv_args = (
+                            qdq_node_after_unsqueeze,
+                            node.args[1],
                             bias_node,
                             stride,
                             padding,
                             dilation,
-                            transpose,
+                            groups,
+                        )
+                    else:
+                        output_padding = (
+                            [0] + node.args[5] if num_args > 5 else [0, 0]
+                        )
+                        groups = node.args[6] if num_args > 6 else 1
+                        dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
+                        conv_args = (
+                            qdq_node_after_unsqueeze,
+                            node.args[1],
+                            bias_node,
+                            stride,
+                            padding,
                             output_padding,
                             groups,
-                        ),
+                            dilation,
+                        )
+                    conv2d_node = graph.create_node(
+                        "call_function",
+                        self.conv_op_map[node.target],
+                        conv_args,
                     )
                     conv2d_node.meta = copy_meta(
                         node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                     )
+                    qdq_node_after_conv2d = self.append_qdq(
+                        graph_module=graph_module,
+                        node=conv2d_node,
+                        qdq_node=list(node.users)[0],
+                    )
 
-                with graph_module.graph.inserting_after(conv2d_node):
-                    squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
+                with graph_module.graph.inserting_after(qdq_node_after_conv2d):
+                    squeeze_op = torch.ops.aten.squeeze_copy.dims
                     squeeze_node = graph.create_node(
                         "call_function",
                         squeeze_op,
                         (
-                            conv2d_node,
+                            qdq_node_after_conv2d,
                             [2],
                         ),
                     )
@@ -102,8 +156,10 @@ def call(self, graph_module: torch.fx.GraphModule):
                             QCOM_REQUANTIZE
                         ]
                         conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
+
                 for user in node.users.copy():
                     user.replace_input_with(node, squeeze_node)
+
         graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
```
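The pass relies on the usual equivalence between 1D and 2D convolution: unsqueeze a height dimension of size 1, convolve in 2D with [1]/[0] prepended to the stride/padding lists exactly as the pass does, then squeeze the height back out. A standalone check of that identity in eager mode (not code from this commit):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 16)  # (N, C, L)
w = torch.randn(8, 4, 3)   # (C_out, C_in, K)
b = torch.randn(8)

y1 = F.conv1d(x, w, b, stride=1, padding=1)

# unsqueeze(2) -> conv2d -> squeeze(2), mirroring what the pass inserts
# around the node (with q/dq ops re-attached in the quantized case)
y2 = F.conv2d(
    x.unsqueeze(2), w.unsqueeze(2), b, stride=[1, 1], padding=[0, 1]
).squeeze(2)

torch.testing.assert_close(y1, y2)
```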

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -200,6 +200,9 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
+        # this pass will rewrite state_dict, it needs to be accomplished before
+        # to_edge_transform_and_lower
+        self.add_pass(ConvertConv1dToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
         self.add_pass(LiftConstantScalarOperands())
         self._transform(exported_program.graph_module)
```
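Since ConvertConv1dToConv2d rewrites parameters via set_parameter, it has to run while the exported program still owns the original state_dict, i.e. before to_edge_transform_and_lower. A rough sketch of the intended ordering; names not in this diff (model, example_inputs, qnn_partitioner) are placeholders:

```python
import torch
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
from executorch.exir import to_edge_transform_and_lower

ep = torch.export.export(model, example_inputs)  # 1. capture

# 2. export-time passes: ConvertConv1dToConv2d mutates the state_dict here,
#    so it must run before lowering takes its snapshot of the weights
QnnPassManager().transform_for_export_pipeline(ep)

# 3. lower afterwards (partitioner setup omitted)
edge = to_edge_transform_and_lower(ep, partitioner=[qnn_partitioner])
```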

backends/qualcomm/_passes/utils.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -103,7 +103,6 @@ def get_passes_dependency_for_capture_program():
         AnnotateStack: [RemoveRedundancy],
         AnnotateUnbind: [RemoveRedundancy],
         ConvertBmmToMatmul: [RecomposePixelUnshuffle],
-        ConvertConv1dToConv2d: [FoldQDQ],
         ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy],
         DecomposeAny: [RemoveRedundancy],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
```

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,7 +40,7 @@ def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor, eps):
         if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
             # scale value equals to zero will cause failure in HTP
             diff = max(abs(tensor.max()), abs(tensor.min())) + eps
-            quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX]
+            quant_attrs[QCOM_SCALE] = (diff / quant_attrs[QCOM_QUANT_MAX]).item()
 
     def define_node(
         self,
```
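The .item() call matters because tensor.max() returns a 0-dim tensor, so without it the computed scale would be stored as a tensor rather than a plain Python float. A quick illustration (not code from this commit):

```python
import torch

t = torch.tensor([0.5, -2.0, 1.5])
diff = max(abs(t.max()), abs(t.min())) + 1e-5  # a 0-dim tensor, not a float

scale = diff / 255
print(type(scale))         # <class 'torch.Tensor'>
print(type(scale.item()))  # <class 'float'>
```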
