|
2 | 2 |
|
3 | 3 | import argparse |
4 | 4 | import check_model |
5 | | -import os |
6 | 5 | from pathlib import Path |
7 | 6 | import subprocess |
8 | 7 | import sys |
9 | 8 | import test_utils |
10 | 9 |
|
11 | 10 |
|
12 | 11 | def main(): |
13 | | - parser = argparse.ArgumentParser(description='Test settings') |
14 | | - # default all: test by both onnx and onnxruntime |
15 | | - # if target is specified, only test by the specified one |
16 | | - parser.add_argument('--target', required=False, default='all', type=str, |
17 | | - help='Test the model by which (onnx/onnxruntime)?', |
18 | | - choices=['onnx', 'onnxruntime', 'all']) |
19 | | - args = parser.parse_args() |
| 12 | + parser = argparse.ArgumentParser(description='Test settings') |
| 13 | + # default all: test by both onnx and onnxruntime |
| 14 | + # if target is specified, only test by the specified one |
| 15 | + parser.add_argument('--target', required=False, default='all', type=str, |
| 16 | + help='Test the model by which (onnx/onnxruntime)?', |
| 17 | + choices=['onnx', 'onnxruntime', 'all']) |
| 18 | + args = parser.parse_args() |
20 | 19 |
|
21 | | - cwd_path = Path.cwd() |
22 | | - # git fetch first for git diff on GitHub Action |
23 | | - subprocess.run(['git', 'fetch', 'origin', 'main:main'], cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
24 | | - # obtain list of added or modified files in this PR |
25 | | - obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/main', 'HEAD'], |
26 | | - cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
27 | | - stdoutput, stderroutput = obtain_diff.communicate() |
28 | | - diff_list = stdoutput.split() |
| 20 | + cwd_path = Path.cwd() |
| 21 | + # git fetch first for git diff on GitHub Action |
| 22 | + subprocess.run(['git', 'fetch', 'origin', 'main:main'], |
| 23 | + cwd=cwd_path, stdout=subprocess.PIPE, |
| 24 | + stderr=subprocess.PIPE) |
| 25 | + # obtain list of added or modified files in this PR |
| 26 | + obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/main', 'HEAD'], |
| 27 | + cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
| 28 | + stdoutput, _ = obtain_diff.communicate() |
| 29 | + diff_list = stdoutput.split() |
29 | 30 |
|
30 | | - # identify list of changed onnx models in model Zoo |
31 | | - model_list = [str(model).replace("b'","").replace("'", "") for model in diff_list if ".onnx" in str(model)] |
32 | | - # run lfs install before starting the tests |
33 | | - test_utils.run_lfs_install() |
| 31 | + # identify list of changed ONNX models in ONXX Model Zoo |
| 32 | + tar_ext_name = '.tar.gz' |
| 33 | + onnx_ext_name = '.onnx' |
| 34 | + model_list = [str(model).replace("b'", "").replace("'", "") |
| 35 | + for model in diff_list if onnx_ext_name in str(model) or tar_ext_name in str(model)] |
| 36 | + # run lfs install before starting the tests |
| 37 | + test_utils.run_lfs_install() |
34 | 38 |
|
35 | | - print('\n=== Running ONNX Checker on added models ===\n') |
36 | | - # run checker on each model |
37 | | - failed_models = [] |
38 | | - tar_ext_name = '.tar.gz' |
39 | | - for model_path in model_list: |
40 | | - model_name = model_path.split('/')[-1] |
41 | | - tar_name = model_name.replace('.onnx', tar_ext_name) |
42 | | - print('==============Testing {}=============='.format(model_name)) |
| 39 | + print('\n=== Running ONNX Checker on added models ===\n') |
| 40 | + # run checker on each model |
| 41 | + failed_models = [] |
| 42 | + for model_path in model_list: |
| 43 | + model_name = model_path.split('/')[-1] |
| 44 | + print('==============Testing {}=============='.format(model_name)) |
43 | 45 |
|
44 | | - try: |
45 | | - # Step 1: check the onnx model and test_data_set from .tar.gz by ORT |
46 | | - # replace '.onnx' with '.tar.gz' |
47 | | - tar_gz_path = model_path[:-5] + '.tar.gz' |
48 | | - print(tar_gz_path) |
49 | | - test_data_set = [] |
50 | | - # if tar.gz exists, git pull and try to get test data |
51 | | - if (args.target == 'onnxruntime' or args.target == 'all') and os.path.exists(tar_gz_path): |
52 | | - test_utils.pull_lfs_file(tar_gz_path) |
53 | | - # check whether 'test_data_set_0' exists |
54 | | - model_path_from_tar, test_data_set = test_utils.extract_test_data(tar_gz_path) |
55 | | - # finally check the onnx model from .tar.gz by ORT |
56 | | - # if the test_data_set does not exist, create the test_data_set |
57 | | - check_model.run_backend_ort(model_path_from_tar, test_data_set) |
58 | | - print('[PASS] {} is checked by onnxruntime. '.format(tar_name)) |
| 46 | + try: |
| 47 | + # check .tar.gz by ORT and ONNX |
| 48 | + if tar_ext_name in model_name: |
| 49 | + # Step 1: check the ONNX model and test_data_set from .tar.gz by ORT |
| 50 | + test_data_set = [] |
| 51 | + test_utils.pull_lfs_file(model_path) |
| 52 | + # check whether 'test_data_set_0' exists |
| 53 | + model_path_from_tar, test_data_set = test_utils.extract_test_data(model_path) |
| 54 | + # if tar.gz exists, git pull and try to get test data |
| 55 | + if (args.target == 'onnxruntime' or args.target == 'all'): |
| 56 | + # finally check the ONNX model from .tar.gz by ORT |
| 57 | + # if the test_data_set does not exist, create the test_data_set |
| 58 | + check_model.run_backend_ort(model_path_from_tar, test_data_set) |
| 59 | + print('[PASS] {} is checked by onnxruntime. '.format(model_name)) |
| 60 | + # Step 2: check the ONNX model inside .tar.gz by ONNX |
| 61 | + if args.target == 'onnx' or args.target == 'all': |
| 62 | + check_model.run_onnx_checker(model_path_from_tar) |
| 63 | + print('[PASS] {} is checked by onnx. '.format(model_name)) |
| 64 | + # check uploaded standalone ONNX model by ONNX |
| 65 | + elif onnx_ext_name in model_name: |
| 66 | + test_utils.pull_lfs_file(model_path) |
| 67 | + if args.target == 'onnx' or args.target == 'all': |
| 68 | + check_model.run_onnx_checker(model_path) |
| 69 | + print('[PASS] {} is checked by onnx. '.format(model_name)) |
59 | 70 |
|
60 | | - # Step 2: check the uploaded onnx model by ONNX |
61 | | - # git pull the onnx file |
62 | | - test_utils.pull_lfs_file(model_path) |
63 | | - # 2. check the uploaded onnx model by ONNX |
64 | | - if args.target == 'onnx' or args.target == 'all': |
65 | | - check_model.run_onnx_checker(model_path) |
66 | | - print('[PASS] {} is checked by onnx. '.format(model_name)) |
| 71 | + except Exception as e: |
| 72 | + print('[FAIL] {}: {}'.format(model_name, e)) |
| 73 | + failed_models.append(model_path) |
| 74 | + test_utils.remove_onnxruntime_test_dir() |
| 75 | + # remove the produced tar directory |
| 76 | + test_utils.remove_tar_dir() |
67 | 77 |
|
68 | | - except Exception as e: |
69 | | - print('[FAIL] {}: {}'.format(model_name, e)) |
70 | | - failed_models.append(model_path) |
71 | | - test_utils.remove_onnxruntime_test_dir() |
| 78 | + if len(failed_models) == 0: |
| 79 | + print('{} models have been checked. '.format(len(model_list))) |
| 80 | + else: |
| 81 | + print('In all {} models, {} models failed. '.format(len(model_list), len(failed_models))) |
| 82 | + sys.exit(1) |
72 | 83 |
|
73 | | - # remove the produced tar directory |
74 | | - test_utils.remove_tar_dir() |
75 | | - |
76 | | - if len(failed_models) == 0: |
77 | | - print('{} models have been checked. '.format(len(model_list))) |
78 | | - else: |
79 | | - print('In all {} models, {} models failed. '.format(len(model_list), len(failed_models))) |
80 | | - sys.exit(1) |
81 | 84 |
|
82 | 85 | if __name__ == '__main__': |
83 | 86 | main() |
0 commit comments