Commit ea7c42e

Merge branch 'main' into use-quantize_
2 parents: 348224c + 7503bb3

27 files changed: +1498 −785 lines


.ci/scripts/test_model.sh

Lines changed: 27 additions & 1 deletion

@@ -188,6 +188,14 @@ test_model_with_qnn() {
     EXPORT_SCRIPT=edsr
     # Additional deps for edsr
     pip install piq
+  elif [[ "${MODEL_NAME}" == "albert" ]]; then
+    EXPORT_SCRIPT=albert
+  elif [[ "${MODEL_NAME}" == "bert" ]]; then
+    EXPORT_SCRIPT=bert
+  elif [[ "${MODEL_NAME}" == "distilbert" ]]; then
+    EXPORT_SCRIPT=distilbert
+  elif [[ "${MODEL_NAME}" == "eurobert" ]]; then
+    EXPORT_SCRIPT=eurobert
   else
     echo "Unsupported model $MODEL_NAME"
     exit 1
@@ -197,7 +205,25 @@ test_model_with_qnn() {
   # TODO(guangyang): Make QNN chipset matches the target device
   QNN_CHIPSET=SM8450

-  "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
+  SCRIPT_FOLDER=""
+  case "${MODEL_NAME}" in
+    "dl3"|"mv3"|"mv2"|"ic4"|"ic3"|"vit"|"mb"|"w2l")
+      SCRIPT_FOLDER=scripts
+      ;;
+    "albert"|"bert"|"distilbert")
+      pip install evaluate
+      SCRIPT_FOLDER=oss_scripts
+      # Bert models running in 16bit will encounter op validation fail on some operations,
+      # which requires CHIPSET >= SM8550.
+      QNN_CHIPSET=SM8550
+      ;;
+    *)
+      echo "Unsupported model $MODEL_NAME"
+      exit 1
+      ;;
+  esac
+
+  "${PYTHON_EXECUTABLE}" -m examples.qualcomm.${SCRIPT_FOLDER}.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
   EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
 }

.github/workflows/trunk.yml

Lines changed: 26 additions & 0 deletions

@@ -480,6 +480,32 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
       PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh ${{ matrix.model }} "cmake" "qnn"

+  test-qnn-optimum-model:
+    name: test-qnn-optimum-model
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    strategy:
+      matrix:
+        dtype: [fp32]
+        model: [albert, bert, distilbert] # eurobert requires transfomer >= 4.48.0, skip for now
+      fail-fast: false
+    with:
+      runner: linux.2xlarge
+      docker-image: executorch-ubuntu-22.04-qnn-sdk
+      submodules: 'recursive'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 900
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
+        PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
+        PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh ${{ matrix.model }} "cmake" "qnn"
+
   test-apple-model:
     name: test-apple-model
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -42,6 +42,9 @@ xcuserdata/
 *.xcworkspace/
 *.xcframework/

+# clangd
+.cache/
+
 # misc
 /.vscode/
 *.so

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,6 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
-from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -19,6 +18,7 @@
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
+from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
 from .fold_qdq import FoldQDQ
@@ -45,7 +45,6 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
-    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -56,6 +55,7 @@
     DecomposeLinalgVectorNorm,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,

backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

backends/qualcomm/_passes/decompose_wrap_with_autocast.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import _operator
+from typing import Dict, Tuple
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_nn_module_stack
+
+
+class DecomposeWrapWithAutocast(ExportPass):
+    """
+    Decompose the _higher_order_ops WrapWithAutocast
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _get_submod(
+        self, gm: torch.fx.GraphModule, node: torch.fx.Node
+    ) -> Tuple[torch.fx.GraphModule, str]:
+        for a in node.args:
+            if isinstance(a, torch.fx.Node) and "submod" in a.target:
+                return getattr(gm, a.target), a.target
+
+    def _replace_output(
+        self, wwac_node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict
+    ):
+        for user in wwac_node.users.copy():
+            arg_idx = 0
+            is_user_getitem = False
+
+            if user.target == _operator.getitem:
+                arg_idx = user.args[1]
+                is_user_getitem = True
+
+            user.replace_input_with(
+                wwac_node,
+                remap[output_node.args[0][arg_idx]],
+            )
+
+            if is_user_getitem:
+                for user_user in user.users.copy():
+                    user_user.replace_input_with(user, user.args[0])
+
+    def _replace(self, gm: torch.fx.GraphModule) -> None:
+        graph = gm.graph
+        for node in graph.nodes:
+            if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast):
+                submod, submod_name = self._get_submod(gm, node)
+                n_args = node.args
+                input_submod = n_args[4]
+                decomposed_module = submod
+                with graph.inserting_before(node):
+                    # remap is used to map original node values to new node values,
+                    # which ensures that reference to nodes are correctly updated in the new graph
+                    # remap = {"expand_1": node.args[5], "to_4": node.args[6]}
+                    remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))}
+
+                    for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
+                        # no need to copy existent 'output'
+                        if decomposed_node.op == "output":
+                            self._replace_output(node, decomposed_node, remap)
+                        # no need to copy existent placeholders
+                        elif decomposed_node.op == "placeholder":
+                            # replace node map from string to graph node
+                            remap[decomposed_node] = remap.pop(decomposed_node.name)
+                        else:
+                            remap[decomposed_node] = graph.node_copy(
+                                decomposed_node,
+                                arg_transform=lambda x, remap=remap: remap[x],
+                            )
+
+                    graph.erase_node(node)
+
+                graph.erase_node(input_submod)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        self._replace(graph_module)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
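
The new pass inlines the wrapped submodule by copying its FX nodes into the parent graph via graph.node_copy with an arg_transform remap, rerouting any getitem users of the higher-order op to the copied outputs, and finally erasing the WrapWithAutocast node and its get_attr input. Below is a minimal, self-contained sketch of that same node_copy plus remap inlining pattern applied to a plain call_module node. Inner, Outer, and KeepInnerTracer are hypothetical modules invented for illustration, only public torch.fx APIs are used, and this is not the commit's pass itself.

import operator

import torch
import torch.fx


class Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        y = self.inner(x)
        return y[0] + y[1]


class KeepInnerTracer(torch.fx.Tracer):
    # Treat Inner as a leaf so it shows up as a single call_module node,
    # analogous to the wrapped submodule behind WrapWithAutocast.
    def is_leaf_module(self, m, qualified_name):
        return isinstance(m, Inner) or super().is_leaf_module(m, qualified_name)


outer = Outer()
gm = torch.fx.GraphModule(outer, KeepInnerTracer().trace(outer))
inner_gm = torch.fx.symbolic_trace(outer.inner)

for node in list(gm.graph.nodes):
    if node.op != "call_module" or node.target != "inner":
        continue
    # Map the submodule's placeholders onto the call site's arguments.
    placeholders = [n for n in inner_gm.graph.nodes if n.op == "placeholder"]
    remap = dict(zip(placeholders, node.args))
    outputs = []
    with gm.graph.inserting_before(node):
        for inner_node in inner_gm.graph.nodes:
            if inner_node.op == "placeholder":
                continue
            if inner_node.op == "output":
                outputs = [remap[a] for a in inner_node.args[0]]
                continue
            # Copy the node, rewriting its inputs through the remap table.
            remap[inner_node] = gm.graph.node_copy(
                inner_node, arg_transform=lambda n: remap[n]
            )
    # Re-route getitem users of the call to the copied output values.
    for user in list(node.users):
        if user.op == "call_function" and user.target == operator.getitem:
            user.replace_all_uses_with(outputs[user.args[1]])
            gm.graph.erase_node(user)
    gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.recompile()
assert torch.equal(gm(torch.ones(2)), Outer()(torch.ones(2)))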

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 2 deletions

@@ -13,7 +13,6 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
-    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -24,6 +23,7 @@
     DecomposeLinalgVectorNorm,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,
@@ -80,7 +80,6 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
-        (ConvertBmmToMatmul, True),
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
@@ -194,6 +193,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
+        self.add_pass(DecomposeWrapWithAutocast())
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
@@ -207,6 +207,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
+        self.add_pass(DecomposeWrapWithAutocast())
         # this pass will rewrite state_dict, it needs to be accomplished before
         # to_edge_transform_and_lower
         self.add_pass(ConvertConv1dToConv2d(exported_program))
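
With the registration above, DecomposeWrapWithAutocast runs in both the annotation and export pipelines, while ConvertBmmToMatmul is dropped from the default capture-program passes. For completeness, here is a short sketch of invoking the new pass on its own; it assumes ExecuTorch's ExportPass keeps the usual torch.fx PassBase convention of being callable on a GraphModule and returning a PassResult, and the helper function name is hypothetical.

import torch

from executorch.backends.qualcomm._passes import DecomposeWrapWithAutocast


def decompose_autocast_wrappers(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # The pass mutates the graph, eliminates dead code, recompiles,
    # and reports the module back in a PassResult.
    result = DecomposeWrapWithAutocast()(graph_module)
    return result.graph_module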

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 11 additions & 13 deletions

@@ -43,6 +43,8 @@ def _dim_order_op_condition(self, node):
         dim_order = node.kwargs.get("dim_order")
         # skip if there contains layout hint
         # e.g. (0, 2, 3, 1) != (0, 1, 2, 3)
+        if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
+            return False
         return dim_order != list(range(len(dim_order)))

     def _to_copy_op_condition(self, node):
@@ -53,19 +55,15 @@ def _default_condition(self, ndoe):

     def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         for n in graph_module.graph.nodes:
-            if n.target not in self.redundant_ops or not self.redundant_ops[n.target](
-                n
-            ):
-                continue
-
-            to_be_remove = n
-            # assert_tensor_metadata op has no user
-            if len(n.users.keys()) == 0:
-                n.args = ()
-            # normal case
-            for user_n in list(n.users.keys()):
-                user_n.replace_input_with(n, n.args[0])
-            graph_module.graph.erase_node(to_be_remove)
+            if n.target in self.redundant_ops and self.redundant_ops[n.target](n):
+                to_be_remove = n
+                # assert_tensor_metadata op has no user
+                if len(n.users.keys()) == 0:
+                    n.args = ()
+                # normal case
+                for user_n in list(n.users.keys()):
+                    user_n.replace_input_with(n, n.args[0])
+                graph_module.graph.erase_node(to_be_remove)

     def call(self, graph_module: torch.fx.GraphModule):
         self._remove(graph_module)
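
The new dtype check keeps any dim-order copy whose output dtype differs from its input, since such a copy performs a real cast; only same-dtype copies can still be flagged by the condition and erased in _remove. A small illustrative sketch of that predicate follows; the standalone function and the SimpleNamespace stand-ins for FX nodes are hypothetical, invented here only to make the check runnable in isolation.

from types import SimpleNamespace

import torch


def is_flagged_redundant(node) -> bool:
    # Mirrors the updated condition: a copy that changes dtype is a real cast, keep it.
    dim_order = node.kwargs.get("dim_order")
    if node.meta["val"].dtype != node.args[0].meta["val"].dtype:
        return False
    return dim_order != list(range(len(dim_order)))


# Stand-in for a node that casts fp32 -> fp16 with an identity dim_order.
src = SimpleNamespace(meta={"val": torch.empty(1, 3, 4, 4, dtype=torch.float32)})
cast = SimpleNamespace(
    kwargs={"dim_order": [0, 1, 2, 3]},
    meta={"val": torch.empty(1, 3, 4, 4, dtype=torch.float16)},
    args=(src,),
)
assert not is_flagged_redundant(cast)  # the cast survives RemoveRedundancy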
