Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/test_model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ test_model_with_qnn() {
# TODO(guangyang): Make QNN chipset matches the target device
QNN_CHIPSET=SM8450

"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only $EXTRA_FLAGS
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
}

Expand Down
46 changes: 17 additions & 29 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@
class AnnotateQuantAttrs(ExportPass):
"""
Add "quant_attrs" to graph nodes' meta from the QDQ information
generated after quatization process.
generated after quantization process.
"""

def __init__(
self, edge_program: torch.export.ExportedProgram, skip_advanced_requat: bool
):
def __init__(self, edge_program: torch.export.ExportedProgram):
super(AnnotateQuantAttrs, self).__init__()
self.edge_program = edge_program
self.skip_advanced_requant = skip_advanced_requat

def _annotate_source_nodes(
self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
Expand Down Expand Up @@ -88,30 +85,21 @@ def _annotate_requant(self, n):
dq_attrs = get_quant_attrs(self.edge_program, dq_node)
# TODO: Store multiple pairs of requantize attributes when we have an op builder
# that has multiple outputs that requires quant attributes.
if self.skip_advanced_requant:
if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
else:
# When dtype is the same but other specs such as scale and offset are different,
# insert requant to improve accuracy.
# Users can turn this feature off if any inference speed drop is observed.
if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
QCOM_SCALE,
QCOM_ZERO_POINT,
QCOM_QUANT_MIN,
QCOM_QUANT_MAX,
QCOM_DTYPE,
]
):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs

if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
QCOM_SCALE,
QCOM_ZERO_POINT,
QCOM_QUANT_MIN,
QCOM_QUANT_MAX,
QCOM_DTYPE,
]
):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs

# Dequant all the fold_quant parameters back to fp32.
# If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
Expand Down
6 changes: 5 additions & 1 deletion examples/qualcomm/scripts/deeplab_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import random
import re
Expand Down Expand Up @@ -74,8 +75,11 @@ def main(args):
)

data_num = 100
if args.compile_only:
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_dataset(
data_size=data_num, dataset_dir=args.artifact, download=args.download
Expand Down
6 changes: 5 additions & 1 deletion examples/qualcomm/scripts/edsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import re
from multiprocessing.connection import Client
Expand Down Expand Up @@ -113,8 +114,11 @@ def main(args):
)

instance = EdsrModel()
if args.compile_only:
if args.ci:
inputs = instance.get_example_inputs()
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
dataset = get_dataset(
args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
Expand Down
6 changes: 5 additions & 1 deletion examples/qualcomm/scripts/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

Expand Down Expand Up @@ -37,8 +38,11 @@ def main(args):
)

data_num = 100
if args.compile_only:
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
Expand Down
8 changes: 6 additions & 2 deletions examples/qualcomm/scripts/inception_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

Expand Down Expand Up @@ -37,8 +38,11 @@ def main(args):
)

data_num = 100
if args.compile_only:
inputs = [(torch.rand(1, 3, 299, 299),)]
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
Expand Down
6 changes: 5 additions & 1 deletion examples/qualcomm/scripts/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

Expand Down Expand Up @@ -37,8 +38,11 @@ def main(args):
)

data_num = 100
if args.compile_only:
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
Expand Down
6 changes: 5 additions & 1 deletion examples/qualcomm/scripts/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

Expand Down Expand Up @@ -36,8 +37,11 @@ def main(args):
)

data_num = 100
if args.compile_only:
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
Expand Down
6 changes: 5 additions & 1 deletion examples/qualcomm/scripts/torchvision_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from multiprocessing.connection import Client

Expand All @@ -28,8 +29,11 @@ def main(args):
os.makedirs(args.artifact, exist_ok=True)

data_num = 100
if args.compile_only:
if args.ci:
inputs = [(torch.rand(1, 3, 224, 224),)]
logging.warning(
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_imagenet_dataset(
dataset_path=f"{args.dataset}",
Expand Down
4 changes: 2 additions & 2 deletions examples/qualcomm/scripts/wav2letter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ def main(args):

# retrieve dataset, will take some time to download
data_num = 100
if args.compile_only:
if args.ci:
inputs = [(torch.rand(1, 1, 700, 1),)]
logging.warning(
"With compile_only, accuracy will be bad due to insufficient datasets for quantization."
"This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
)
else:
inputs, targets, input_list = get_dataset(
Expand Down
7 changes: 7 additions & 0 deletions examples/qualcomm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,13 @@ def setup_common_args_and_variables():
action="store_true",
)

parser.add_argument(
"--ci",
help="This flag is for Continuous Integration(CI) purpose and is NOT recommended to turn on for typical use cases. It will use random inputs instead of real inputs.",
action="store_true",
default=False,
)

# QNN_SDK_ROOT might also be an argument, but it is used in various places.
# So maybe it's fine to just use the environment.
if "QNN_SDK_ROOT" not in os.environ:
Expand Down
Loading