|
2 | 2 |
|
3 | 3 | import argparse
|
4 | 4 | import check_model
|
5 |
| -import os |
6 | 5 | from pathlib import Path
|
7 | 6 | import subprocess
|
8 | 7 | import sys
|
9 | 8 | import test_utils
|
10 | 9 |
|
11 | 10 |
|
12 | 11 | def main():
|
13 |
| - parser = argparse.ArgumentParser(description='Test settings') |
14 |
| - # default all: test by both onnx and onnxruntime |
15 |
| - # if target is specified, only test by the specified one |
16 |
| - parser.add_argument('--target', required=False, default='all', type=str, |
17 |
| - help='Test the model by which (onnx/onnxruntime)?', |
18 |
| - choices=['onnx', 'onnxruntime', 'all']) |
19 |
| - args = parser.parse_args() |
| 12 | + parser = argparse.ArgumentParser(description='Test settings') |
| 13 | + # default all: test by both onnx and onnxruntime |
| 14 | + # if target is specified, only test by the specified one |
| 15 | + parser.add_argument('--target', required=False, default='all', type=str, |
| 16 | + help='Test the model by which (onnx/onnxruntime)?', |
| 17 | + choices=['onnx', 'onnxruntime', 'all']) |
| 18 | + args = parser.parse_args() |
20 | 19 |
|
21 |
| - cwd_path = Path.cwd() |
22 |
| - # git fetch first for git diff on GitHub Action |
23 |
| - subprocess.run(['git', 'fetch', 'origin', 'main:main'], cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
24 |
| - # obtain list of added or modified files in this PR |
25 |
| - obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/main', 'HEAD'], |
26 |
| - cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
27 |
| - stdoutput, stderroutput = obtain_diff.communicate() |
28 |
| - diff_list = stdoutput.split() |
| 20 | + cwd_path = Path.cwd() |
| 21 | + # git fetch first for git diff on GitHub Action |
| 22 | + subprocess.run(['git', 'fetch', 'origin', 'main:main'], |
| 23 | + cwd=cwd_path, stdout=subprocess.PIPE, |
| 24 | + stderr=subprocess.PIPE) |
| 25 | + # obtain list of added or modified files in this PR |
| 26 | + obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/main', 'HEAD'], |
| 27 | + cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
| 28 | + stdoutput, _ = obtain_diff.communicate() |
| 29 | + diff_list = stdoutput.split() |
29 | 30 |
|
30 |
| - # identify list of changed onnx models in model Zoo |
31 |
| - model_list = [str(model).replace("b'","").replace("'", "") for model in diff_list if ".onnx" in str(model)] |
32 |
| - # run lfs install before starting the tests |
33 |
| - test_utils.run_lfs_install() |
| 31 | + # identify list of changed ONNX models in ONXX Model Zoo |
| 32 | + tar_ext_name = '.tar.gz' |
| 33 | + onnx_ext_name = '.onnx' |
| 34 | + model_list = [str(model).replace("b'", "").replace("'", "") |
| 35 | + for model in diff_list if onnx_ext_name in str(model) or tar_ext_name in str(model)] |
| 36 | + # run lfs install before starting the tests |
| 37 | + test_utils.run_lfs_install() |
34 | 38 |
|
35 |
| - print('\n=== Running ONNX Checker on added models ===\n') |
36 |
| - # run checker on each model |
37 |
| - failed_models = [] |
38 |
| - tar_ext_name = '.tar.gz' |
39 |
| - for model_path in model_list: |
40 |
| - model_name = model_path.split('/')[-1] |
41 |
| - tar_name = model_name.replace('.onnx', tar_ext_name) |
42 |
| - print('==============Testing {}=============='.format(model_name)) |
| 39 | + print('\n=== Running ONNX Checker on added models ===\n') |
| 40 | + # run checker on each model |
| 41 | + failed_models = [] |
| 42 | + for model_path in model_list: |
| 43 | + model_name = model_path.split('/')[-1] |
| 44 | + print('==============Testing {}=============='.format(model_name)) |
43 | 45 |
|
44 |
| - try: |
45 |
| - # Step 1: check the onnx model and test_data_set from .tar.gz by ORT |
46 |
| - # replace '.onnx' with '.tar.gz' |
47 |
| - tar_gz_path = model_path[:-5] + '.tar.gz' |
48 |
| - print(tar_gz_path) |
49 |
| - test_data_set = [] |
50 |
| - # if tar.gz exists, git pull and try to get test data |
51 |
| - if (args.target == 'onnxruntime' or args.target == 'all') and os.path.exists(tar_gz_path): |
52 |
| - test_utils.pull_lfs_file(tar_gz_path) |
53 |
| - # check whether 'test_data_set_0' exists |
54 |
| - model_path_from_tar, test_data_set = test_utils.extract_test_data(tar_gz_path) |
55 |
| - # finally check the onnx model from .tar.gz by ORT |
56 |
| - # if the test_data_set does not exist, create the test_data_set |
57 |
| - check_model.run_backend_ort(model_path_from_tar, test_data_set) |
58 |
| - print('[PASS] {} is checked by onnxruntime. '.format(tar_name)) |
| 46 | + try: |
| 47 | + # check .tar.gz by ORT and ONNX |
| 48 | + if tar_ext_name in model_name: |
| 49 | + # Step 1: check the ONNX model and test_data_set from .tar.gz by ORT |
| 50 | + test_data_set = [] |
| 51 | + test_utils.pull_lfs_file(model_path) |
| 52 | + # check whether 'test_data_set_0' exists |
| 53 | + model_path_from_tar, test_data_set = test_utils.extract_test_data(model_path) |
| 54 | + # if tar.gz exists, git pull and try to get test data |
| 55 | + if (args.target == 'onnxruntime' or args.target == 'all'): |
| 56 | + # finally check the ONNX model from .tar.gz by ORT |
| 57 | + # if the test_data_set does not exist, create the test_data_set |
| 58 | + check_model.run_backend_ort(model_path_from_tar, test_data_set) |
| 59 | + print('[PASS] {} is checked by onnxruntime. '.format(model_name)) |
| 60 | + # Step 2: check the ONNX model inside .tar.gz by ONNX |
| 61 | + if args.target == 'onnx' or args.target == 'all': |
| 62 | + check_model.run_onnx_checker(model_path_from_tar) |
| 63 | + print('[PASS] {} is checked by onnx. '.format(model_name)) |
| 64 | + # check uploaded standalone ONNX model by ONNX |
| 65 | + elif onnx_ext_name in model_name: |
| 66 | + test_utils.pull_lfs_file(model_path) |
| 67 | + if args.target == 'onnx' or args.target == 'all': |
| 68 | + check_model.run_onnx_checker(model_path) |
| 69 | + print('[PASS] {} is checked by onnx. '.format(model_name)) |
59 | 70 |
|
60 |
| - # Step 2: check the uploaded onnx model by ONNX |
61 |
| - # git pull the onnx file |
62 |
| - test_utils.pull_lfs_file(model_path) |
63 |
| - # 2. check the uploaded onnx model by ONNX |
64 |
| - if args.target == 'onnx' or args.target == 'all': |
65 |
| - check_model.run_onnx_checker(model_path) |
66 |
| - print('[PASS] {} is checked by onnx. '.format(model_name)) |
| 71 | + except Exception as e: |
| 72 | + print('[FAIL] {}: {}'.format(model_name, e)) |
| 73 | + failed_models.append(model_path) |
| 74 | + test_utils.remove_onnxruntime_test_dir() |
| 75 | + # remove the produced tar directory |
| 76 | + test_utils.remove_tar_dir() |
67 | 77 |
|
68 |
| - except Exception as e: |
69 |
| - print('[FAIL] {}: {}'.format(model_name, e)) |
70 |
| - failed_models.append(model_path) |
71 |
| - test_utils.remove_onnxruntime_test_dir() |
| 78 | + if len(failed_models) == 0: |
| 79 | + print('{} models have been checked. '.format(len(model_list))) |
| 80 | + else: |
| 81 | + print('In all {} models, {} models failed. '.format(len(model_list), len(failed_models))) |
| 82 | + sys.exit(1) |
72 | 83 |
|
73 |
| - # remove the produced tar directory |
74 |
| - test_utils.remove_tar_dir() |
75 |
| - |
76 |
| - if len(failed_models) == 0: |
77 |
| - print('{} models have been checked. '.format(len(model_list))) |
78 |
| - else: |
79 |
| - print('In all {} models, {} models failed. '.format(len(model_list), len(failed_models))) |
80 |
| - sys.exit(1) |
81 | 84 |
|
82 | 85 | if __name__ == '__main__':
|
83 | 86 | main()
|
0 commit comments