Skip to content

Commit a8a7e1b

Browse files
Qualcomm AI Engine Direct - GA model enablement (Swin Transformer) (#11099)
### Summary - e2e script / test case for GA Swin Transformer model (https://huggingface.co/microsoft/swin-tiny-patch4-window7-224) - perf: 8a8w 3.9ms/inf (SM8750) - acc: top1/5 ~= 70%/85% in ImageNet (https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000) - add pass to handle torch.roll - properly annotate aten.masked_fill - fix embedding UT in quantization flow ### Test plan ``` python examples/qualcomm/oss_scripts/swin_transformer.py -b build-android -H $HOST -s $DEVICE -m $MODEL --dataset ${PATH_TO_DATASET} ```
1 parent 3351ecc commit a8a7e1b

20 files changed

+536
-37
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
78
from .annotate_quant_attrs import AnnotateQuantAttrs
89
from .annotate_stack import AnnotateStack
910
from .annotate_unbind import AnnotateUnbind
@@ -16,6 +17,7 @@
1617
from .decompose_einsum import DecomposeEinsum
1718
from .decompose_expm1 import DecomposeExpM1
1819
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
20+
from .decompose_roll import DecomposeRoll
1921
from .decompose_silu import DecomposeSilu
2022
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
2123
from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -39,6 +41,7 @@
3941

4042

4143
__all__ = [
44+
AnnotateAdaptiveAvgPool1D,
4245
AnnotateQuantAttrs,
4346
AnnotateStack,
4447
AnnotateUnbind,
@@ -51,6 +54,7 @@
5154
DecomposeEinsum,
5255
DecomposeExpM1,
5356
DecomposeLinalgVectorNorm,
57+
DecomposeRoll,
5458
DecomposeSilu,
5559
ExpandBroadcastTensorShape,
5660
FixedLinearKeepDim,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
8+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
11+
12+
from .utils import get_quant_attrs
13+
14+
15+
class AnnotateAdaptiveAvgPool1D(ExportPass):
16+
"""
17+
Add "quant_attrs" to graph nodes' meta from the QDQ information
18+
generated after quantization process.
19+
adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
20+
"""
21+
22+
decomp_ops = [torch.ops.aten.adaptive_avg_pool2d.default]
23+
24+
def __init__(self, edge_program: torch.export.ExportedProgram):
25+
super(AnnotateAdaptiveAvgPool1D, self).__init__()
26+
self.edge_program = edge_program
27+
28+
def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
29+
partitions = get_source_partitions(
30+
graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
31+
)
32+
for src_partitions in partitions.values():
33+
for src_partition in src_partitions:
34+
output = src_partition.output_nodes[0]
35+
if (list(output.users)[0].target) in q_ops:
36+
quant_attrs = get_quant_attrs(
37+
self.edge_program, list(output.users)[0]
38+
)
39+
for n in src_partition.nodes:
40+
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
41+
42+
def call(self, graph_module: torch.fx.GraphModule):
43+
self._annotate_adaptive_avg_pool1d(graph_module)
44+
graph_module.recompile()
45+
return PassResult(graph_module, True)

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch
10+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
1011
from executorch.backends.qualcomm.builders.utils import get_parameter
1112
from executorch.backends.qualcomm.utils.constants import (
1213
QCOM_DTYPE,
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.exir.pass_base import ExportPass, PassResult
2223

23-
from .utils import dq_ops, get_quant_attrs, q_ops
24+
from .utils import get_quant_attrs
2425

2526

2627
class AnnotateQuantAttrs(ExportPass):

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import get_quant_attrs, q_ops
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateStack(ExportPass):

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import dq_ops, get_quant_attrs
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateUnbind(ExportPass):
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import copy_nn_module_stack
11+
12+
13+
class SliceCopy(torch.nn.Module):
14+
def __init__(self, val_shape, shifts, dims):
15+
super().__init__()
16+
self.val_shape = val_shape
17+
if dims[0] is None:
18+
self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))]
19+
else:
20+
self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)]
21+
self.dims = dims
22+
23+
def forward(self, x):
24+
if self.dims[0] is None:
25+
y = x.flatten()
26+
y = torch.cat((y[-self.shifts[0] :], y[: -self.shifts[0]]))
27+
return y.view(self.val_shape)
28+
29+
for shift, dim in zip(self.shifts, self.dims):
30+
x = torch.cat(
31+
(
32+
x[(slice(None),) * dim + (slice(-shift, None),)],
33+
x[(slice(None),) * dim + (slice(0, -shift),)],
34+
),
35+
dim=dim,
36+
)
37+
return x
38+
39+
40+
class DecomposeRoll(ExportPass):
41+
"""
42+
Decompose roll into slice and cat.
43+
"""
44+
45+
def __init__(self) -> None:
46+
super().__init__()
47+
48+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
49+
graph = graph_module.graph
50+
for node in graph.nodes:
51+
if "roll" in str(node.target):
52+
input_node, shifts = node.args[0], node.args[1]
53+
dims = node.args[2] if len(node.args) == 3 else None
54+
55+
# Normalize shifts and dims to lists
56+
shifts = shifts if isinstance(shifts, (list, tuple)) else [shifts]
57+
dims = dims if isinstance(dims, (list, tuple)) else [dims]
58+
59+
model = SliceCopy(input_node.meta["val"].shape, shifts, dims)
60+
decomposed_module = torch.export.export(
61+
model, (input_node.meta["val"],), strict=True
62+
).module()
63+
64+
with graph.inserting_before(node):
65+
# remap is used to map original node values to new node values,
66+
# which ensures that reference to nodes are correctly updated in the new graph
67+
remap = {"x": input_node}
68+
69+
for decomposed_node in decomposed_module.graph.nodes:
70+
copy_nn_module_stack(node, decomposed_node)
71+
# no need to copy existent 'output'
72+
if decomposed_node.op == "output":
73+
for user in node.users.copy():
74+
# remap
75+
user.replace_input_with(
76+
node,
77+
remap[decomposed_node.args[0][0]],
78+
)
79+
# no need to copy existent placeholders
80+
elif decomposed_node.op == "placeholder":
81+
# replace node map from string to graph node
82+
remap[decomposed_node] = remap.pop(decomposed_node.name)
83+
else:
84+
remap[decomposed_node] = graph.node_copy(
85+
decomposed_node,
86+
arg_transform=lambda x, remap=remap: remap[x],
87+
)
88+
89+
graph.erase_node(node)
90+
91+
graph.eliminate_dead_code()
92+
graph_module.recompile()
93+
return PassResult(graph_module, True)

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
89
from executorch.exir.dialects._ops import ops as exir_ops
910
from executorch.exir.pass_base import ExportPass, PassResult
1011
from executorch.exir.passes import dead_code_elimination_pass
1112

12-
from .utils import dq_ops
13-
1413

1514
class ExpandBroadcastTensorShape(ExportPass):
1615
"""

backends/qualcomm/_passes/fold_qdq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
78
from executorch.backends.qualcomm.builders.utils import is_parameter
89
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
910
from executorch.exir.dialects._ops import ops as exir_ops
1011
from executorch.exir.pass_base import ExportPass, PassResult
1112
from executorch.exir.passes import dead_code_elimination_pass
1213

13-
from .utils import dq_ops, q_ops
14-
1514

1615
class FoldQDQ(ExportPass):
1716
"""

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99

10+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
11+
1012
from executorch.backends.qualcomm.builders.utils import is_parameter
1113
from executorch.backends.qualcomm.utils.constants import (
1214
QCOM_ENCODING,
@@ -16,8 +18,6 @@
1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.pass_base import ExportPass, PassResult
1820

19-
from .utils import q_ops
20-
2121

2222
class InsertIOQDQ(ExportPass):
2323
"""

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class TensorOpInfo:
5050
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
5151
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
5252
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
53+
aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
5354
}
5455

5556

@@ -78,7 +79,7 @@ def _build_tensor_constant(
7879
# For dtype, in some cases, we cannot use node.args[0] as scalar dtype.
7980
# Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type
8081
tensor = torch.tensor(
81-
[const_val],
82+
const_val,
8283
dtype=(
8384
node.args[0].meta["val"].dtype
8485
if not is_float_tensor(node)

0 commit comments

Comments
 (0)