Skip to content

Commit c230eea

Browse files
Merge branch 'main' into xinyuye/yolov4
2 parents 328fe12 + 26fe75e commit c230eea

File tree

3 files changed

+69
-63
lines changed

3 files changed

+69
-63
lines changed

vision/classification/shufflenet/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ONNX ShuffleNet-v2 ==> Quantized ONNX ShuffleNet-v2
3030
|-------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|
3131
|ShuffleNet-v2 |[9.2MB](model/shufflenet-v2-10.onnx) | [8.7MB](model/shufflenet-v2-10.tar.gz) | 1.6 | 10 | 30.64 | 11.68|
3232
|ShuffleNet-v2-fp32 |[8.79MB](model/shufflenet-v2-12.onnx) |[8.69MB](model/shufflenet-v2-12.tar.gz) |1.9 |12 |33.65 |13.43|
33-
|ShuffleNet-v2-int8 |[2.28MB](model/shufflenet-v2-12-int8.onnx) |[2.37MB](model/shufflenet-v2-10-int8.tar.gz) |1.9 |12 |33.85 |13.66 |
33+
|ShuffleNet-v2-int8 |[2.28MB](model/shufflenet-v2-12-int8.onnx) |[2.37MB](model/shufflenet-v2-12-int8.tar.gz) |1.9 |12 |33.85 |13.66 |
3434
> Compared with the fp32 ShuffleNet-v2, int8 ShuffleNet-v2's Top-1 error rising ratio is 0.59%, Top-5 error rising ratio is 1.71% and performance improvement is 1.62x.
3535
>
3636
> Note the performance depends on the test hardware.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:477a77f7ba31bc6bebe6e9824cb0108173d3bb0c54506d8c6663ea36eee7dfb4
3+
size 105964121

workflow_scripts/test_models.py

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,82 +2,85 @@
22

33
import argparse
44
import check_model
5-
import os
65
from pathlib import Path
76
import subprocess
87
import sys
98
import test_utils
109

1110

1211
def main():
13-
parser = argparse.ArgumentParser(description='Test settings')
14-
# default all: test by both onnx and onnxruntime
15-
# if target is specified, only test by the specified one
16-
parser.add_argument('--target', required=False, default='all', type=str,
17-
help='Test the model by which (onnx/onnxruntime)?',
18-
choices=['onnx', 'onnxruntime', 'all'])
19-
args = parser.parse_args()
12+
parser = argparse.ArgumentParser(description='Test settings')
13+
# default all: test by both onnx and onnxruntime
14+
# if target is specified, only test by the specified one
15+
parser.add_argument('--target', required=False, default='all', type=str,
16+
help='Test the model by which (onnx/onnxruntime)?',
17+
choices=['onnx', 'onnxruntime', 'all'])
18+
args = parser.parse_args()
2019

21-
cwd_path = Path.cwd()
22-
# git fetch first for git diff on GitHub Action
23-
subprocess.run(['git', 'fetch', 'origin', 'main:main'], cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
24-
# obtain list of added or modified files in this PR
25-
obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/main', 'HEAD'],
26-
cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
27-
stdoutput, stderroutput = obtain_diff.communicate()
28-
diff_list = stdoutput.split()
20+
cwd_path = Path.cwd()
21+
# git fetch first for git diff on GitHub Action
22+
subprocess.run(['git', 'fetch', 'origin', 'main:main'],
23+
cwd=cwd_path, stdout=subprocess.PIPE,
24+
stderr=subprocess.PIPE)
25+
# obtain list of added or modified files in this PR
26+
obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/main', 'HEAD'],
27+
cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
28+
stdoutput, _ = obtain_diff.communicate()
29+
diff_list = stdoutput.split()
2930

30-
# identify list of changed onnx models in model Zoo
31-
model_list = [str(model).replace("b'","").replace("'", "") for model in diff_list if ".onnx" in str(model)]
32-
# run lfs install before starting the tests
33-
test_utils.run_lfs_install()
31+
# identify list of changed ONNX models in ONXX Model Zoo
32+
tar_ext_name = '.tar.gz'
33+
onnx_ext_name = '.onnx'
34+
model_list = [str(model).replace("b'", "").replace("'", "")
35+
for model in diff_list if onnx_ext_name in str(model) or tar_ext_name in str(model)]
36+
# run lfs install before starting the tests
37+
test_utils.run_lfs_install()
3438

35-
print('\n=== Running ONNX Checker on added models ===\n')
36-
# run checker on each model
37-
failed_models = []
38-
tar_ext_name = '.tar.gz'
39-
for model_path in model_list:
40-
model_name = model_path.split('/')[-1]
41-
tar_name = model_name.replace('.onnx', tar_ext_name)
42-
print('==============Testing {}=============='.format(model_name))
39+
print('\n=== Running ONNX Checker on added models ===\n')
40+
# run checker on each model
41+
failed_models = []
42+
for model_path in model_list:
43+
model_name = model_path.split('/')[-1]
44+
print('==============Testing {}=============='.format(model_name))
4345

44-
try:
45-
# Step 1: check the onnx model and test_data_set from .tar.gz by ORT
46-
# replace '.onnx' with '.tar.gz'
47-
tar_gz_path = model_path[:-5] + '.tar.gz'
48-
print(tar_gz_path)
49-
test_data_set = []
50-
# if tar.gz exists, git pull and try to get test data
51-
if (args.target == 'onnxruntime' or args.target == 'all') and os.path.exists(tar_gz_path):
52-
test_utils.pull_lfs_file(tar_gz_path)
53-
# check whether 'test_data_set_0' exists
54-
model_path_from_tar, test_data_set = test_utils.extract_test_data(tar_gz_path)
55-
# finally check the onnx model from .tar.gz by ORT
56-
# if the test_data_set does not exist, create the test_data_set
57-
check_model.run_backend_ort(model_path_from_tar, test_data_set)
58-
print('[PASS] {} is checked by onnxruntime. '.format(tar_name))
46+
try:
47+
# check .tar.gz by ORT and ONNX
48+
if tar_ext_name in model_name:
49+
# Step 1: check the ONNX model and test_data_set from .tar.gz by ORT
50+
test_data_set = []
51+
test_utils.pull_lfs_file(model_path)
52+
# check whether 'test_data_set_0' exists
53+
model_path_from_tar, test_data_set = test_utils.extract_test_data(model_path)
54+
# if tar.gz exists, git pull and try to get test data
55+
if (args.target == 'onnxruntime' or args.target == 'all'):
56+
# finally check the ONNX model from .tar.gz by ORT
57+
# if the test_data_set does not exist, create the test_data_set
58+
check_model.run_backend_ort(model_path_from_tar, test_data_set)
59+
print('[PASS] {} is checked by onnxruntime. '.format(model_name))
60+
# Step 2: check the ONNX model inside .tar.gz by ONNX
61+
if args.target == 'onnx' or args.target == 'all':
62+
check_model.run_onnx_checker(model_path_from_tar)
63+
print('[PASS] {} is checked by onnx. '.format(model_name))
64+
# check uploaded standalone ONNX model by ONNX
65+
elif onnx_ext_name in model_name:
66+
test_utils.pull_lfs_file(model_path)
67+
if args.target == 'onnx' or args.target == 'all':
68+
check_model.run_onnx_checker(model_path)
69+
print('[PASS] {} is checked by onnx. '.format(model_name))
5970

60-
# Step 2: check the uploaded onnx model by ONNX
61-
# git pull the onnx file
62-
test_utils.pull_lfs_file(model_path)
63-
# 2. check the uploaded onnx model by ONNX
64-
if args.target == 'onnx' or args.target == 'all':
65-
check_model.run_onnx_checker(model_path)
66-
print('[PASS] {} is checked by onnx. '.format(model_name))
71+
except Exception as e:
72+
print('[FAIL] {}: {}'.format(model_name, e))
73+
failed_models.append(model_path)
74+
test_utils.remove_onnxruntime_test_dir()
75+
# remove the produced tar directory
76+
test_utils.remove_tar_dir()
6777

68-
except Exception as e:
69-
print('[FAIL] {}: {}'.format(model_name, e))
70-
failed_models.append(model_path)
71-
test_utils.remove_onnxruntime_test_dir()
78+
if len(failed_models) == 0:
79+
print('{} models have been checked. '.format(len(model_list)))
80+
else:
81+
print('In all {} models, {} models failed. '.format(len(model_list), len(failed_models)))
82+
sys.exit(1)
7283

73-
# remove the produced tar directory
74-
test_utils.remove_tar_dir()
75-
76-
if len(failed_models) == 0:
77-
print('{} models have been checked. '.format(len(model_list)))
78-
else:
79-
print('In all {} models, {} models failed. '.format(len(model_list), len(failed_models)))
80-
sys.exit(1)
8184

8285
if __name__ == '__main__':
8386
main()

0 commit comments

Comments
 (0)