Skip to content

Commit c6aa251

Browse files
chenweng-quichaowhsu-quic
authored andcommitted
Qualcomm AI Engine Direct - GA Model Enablement (deit)
Summary - support e2e script / test case for GA deit model - perf: 8a8w 2.7ms/inf - acc: top1/5 ~= 85%/95%
1 parent d5c4ba7 commit c6aa251

File tree

2 files changed

+182
-0
lines changed

2 files changed

+182
-0
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3880,6 +3880,41 @@ def test_conv_former(self):
38803880
self.assertGreaterEqual(msg["top_1"], 60)
38813881
self.assertGreaterEqual(msg["top_5"], 80)
38823882

3883+
def test_deit(self):
3884+
if not self.required_envs([self.image_dataset]):
3885+
self.skipTest("missing required envs")
3886+
cmds = [
3887+
"python",
3888+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/deit.py",
3889+
"--dataset",
3890+
self.image_dataset,
3891+
"--artifact",
3892+
self.artifact_dir,
3893+
"--build_folder",
3894+
self.build_folder,
3895+
"--device",
3896+
self.device,
3897+
"--model",
3898+
self.model,
3899+
"--ip",
3900+
self.ip,
3901+
"--port",
3902+
str(self.port),
3903+
]
3904+
if self.host:
3905+
cmds.extend(["--host", self.host])
3906+
3907+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
3908+
with Listener((self.ip, self.port)) as listener:
3909+
conn = listener.accept()
3910+
p.communicate()
3911+
msg = json.loads(conn.recv())
3912+
if "Error" in msg:
3913+
self.fail(msg["Error"])
3914+
else:
3915+
self.assertGreaterEqual(msg["top_1"], 75)
3916+
self.assertGreaterEqual(msg["top_5"], 90)
3917+
38833918
def test_dino_v2(self):
38843919
if not self.required_envs([self.image_dataset]):
38853920
self.skipTest("missing required envs")
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import getpass
8+
import json
9+
import os
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
import torch
14+
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
15+
get_capture_program_passes,
16+
)
17+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
18+
from executorch.examples.qualcomm.utils import (
19+
build_executorch_binary,
20+
get_imagenet_dataset,
21+
make_output_dir,
22+
parse_skip_delegation_node,
23+
setup_common_args_and_variables,
24+
SimpleADB,
25+
topk_accuracy,
26+
)
27+
from transformers import AutoConfig, AutoModelForImageClassification
28+
29+
30+
def get_instance():
31+
module = (
32+
AutoModelForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
33+
.eval()
34+
.to("cpu")
35+
)
36+
37+
return module
38+
39+
40+
def main(args):
41+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
42+
43+
os.makedirs(args.artifact, exist_ok=True)
44+
config = AutoConfig.from_pretrained("facebook/deit-base-distilled-patch16-224")
45+
data_num = 100
46+
height = config.image_size
47+
width = config.image_size
48+
inputs, targets, input_list = get_imagenet_dataset(
49+
dataset_path=f"{args.dataset}",
50+
data_size=data_num,
51+
image_shape=(height, width),
52+
crop_size=(height, width),
53+
)
54+
55+
# Get the Deit model.
56+
model = get_instance()
57+
pte_filename = "deit_qnn"
58+
59+
# lower to QNN
60+
passes_job = get_capture_program_passes()
61+
build_executorch_binary(
62+
model,
63+
inputs[0],
64+
args.model,
65+
f"{args.artifact}/{pte_filename}",
66+
dataset=inputs,
67+
skip_node_id_set=skip_node_id_set,
68+
skip_node_op_set=skip_node_op_set,
69+
quant_dtype=QuantDtype.use_8a8w,
70+
passes_job=passes_job,
71+
shared_buffer=args.shared_buffer,
72+
)
73+
74+
if args.compile_only:
75+
return
76+
77+
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
78+
pte_path = f"{args.artifact}/{pte_filename}.pte"
79+
80+
adb = SimpleADB(
81+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
82+
build_path=f"{args.build_folder}",
83+
pte_path=pte_path,
84+
workspace=workspace,
85+
device_id=args.device,
86+
host_id=args.host,
87+
soc_model=args.model,
88+
)
89+
adb.push(inputs=inputs, input_list=input_list)
90+
adb.execute()
91+
92+
# collect output data
93+
output_data_folder = f"{args.artifact}/outputs"
94+
make_output_dir(output_data_folder)
95+
96+
adb.pull(output_path=args.artifact)
97+
98+
# top-k analysis
99+
predictions = []
100+
for i in range(data_num):
101+
predictions.append(
102+
np.fromfile(
103+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
104+
)
105+
)
106+
107+
k_val = [1, 5]
108+
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
109+
if args.ip and args.port != -1:
110+
with Client((args.ip, args.port)) as conn:
111+
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
112+
else:
113+
for i, k in enumerate(k_val):
114+
print(f"top_{k}->{topk[i]}%")
115+
116+
117+
if __name__ == "__main__":
118+
parser = setup_common_args_and_variables()
119+
parser.add_argument(
120+
"-a",
121+
"--artifact",
122+
help="path for storing generated artifacts and output by this example. Default ./deit_qnn",
123+
default="./deit_qnn",
124+
type=str,
125+
)
126+
127+
parser.add_argument(
128+
"-d",
129+
"--dataset",
130+
help=(
131+
"path to the validation folder of ImageNet dataset. "
132+
"e.g. --dataset imagenet-mini/val "
133+
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
134+
),
135+
type=str,
136+
required=True,
137+
)
138+
139+
args = parser.parse_args()
140+
try:
141+
main(args)
142+
except Exception as e:
143+
if args.ip and args.port != -1:
144+
with Client((args.ip, args.port)) as conn:
145+
conn.send(json.dumps({"Error": str(e)}))
146+
else:
147+
raise Exception(e)

0 commit comments

Comments
 (0)