|
10 | 10 | import test_utils |
11 | 11 |
|
12 | 12 |
|
def has_avx512f_support():
    """Return True when the host CPU advertises the avx512f instruction-set flag."""
    return any(flag == "avx512f" for flag in get_cpu_info()["flags"])
19 | 15 |
|
20 | 16 |
|
def run_onnx_checker(model_path):
    """Load the ONNX model at *model_path* and validate it with the ONNX checker.

    Raises onnx.checker.ValidationError (propagated) when the model is invalid.
    """
    onnx.checker.check_model(onnx.load(model_path))
24 | 20 |
|
25 | 21 |
|
def ort_skip_reason(model_path):
    """Return a human-readable reason to skip the ORT test for *model_path*,
    or None when the model should be tested.

    Skip conditions:
    - quantized models ("-int8" / "-qdq" in the path) when this host lacks
      avx512f, because the reference output.pb was produced with avx512f; or
    - models whose default-domain opset is older than 7, which ORT does not
      guarantee to support.
    """
    if ("-int8" in model_path or "-qdq" in model_path) and not has_avx512f_support():
        return f"Skip ORT test for {model_path} because this machine lacks avx512f support and the output.pb was produced with avx512f support."
    model = onnx.load(model_path)
    # Look up the default ONNX domain explicitly instead of assuming it is the
    # first opset_import entry — repeated-field order is not guaranteed, and a
    # model listing a custom domain first would otherwise be checked wrongly.
    default_opset = next(
        (entry.version for entry in model.opset_import if entry.domain in ("", "ai.onnx")),
        None,
    )
    if default_opset is not None and default_opset < 7:
        return f"Skip ORT test for {model_path} because ORT only supports opset version >= 7"
    return None
| 30 | + |
def make_tarfile(output_filename, source_dir):
    """Pack *source_dir* into a gzip-compressed, GNU-format tarball.

    The directory is stored under its basename, so the archive unpacks as a
    single top-level folder regardless of where *source_dir* lives.
    """
    archive_root = os.path.basename(source_dir)
    tar = tarfile.open(output_filename, "w:gz", format=tarfile.GNU_FORMAT)
    try:
        tar.add(source_dir, arcname=archive_root)
    finally:
        tar.close()
29 | 34 |
|
30 | 35 |
|
def run_backend_ort(model_path, test_data_set=None, tar_gz_path=None):
    """Validate *model_path* with onnxruntime and package ORT test data.

    Parameters:
        model_path: path to the .onnx model file.
        test_data_set: existing "test_data_set_N" directory to reuse; when
            falsy, fresh test data is generated via ort_test_dir_utils.
        tar_gz_path: destination .tar.gz for the generated test directory
            (used only when test_data_set is falsy).

    Prints a skip message and returns early when ort_skip_reason() says the
    model should not be exercised on this machine.
    """
    skip_reason = ort_skip_reason(model_path)
    if skip_reason:
        print(skip_reason)
        return
    # if "test_data_set_N" doesn't exist, create test_dir
    if not test_data_set:
        # Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
        # other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
        # based on the build flags) when instantiating InferenceSession.
        # For example, if NVIDIA GPU is available and ORT Python package is built with CUDA, then call API as following:
        # onnxruntime.InferenceSession(path/to/model, providers=["CUDAExecutionProvider"])
        # Session is constructed only to confirm ORT can load the model.
        onnxruntime.InferenceSession(model_path)
        # Get model name without .onnx
        model_name = os.path.basename(os.path.splitext(model_path)[0])
        # Bug fix: os.path.basename never returns None — the actual failure
        # value is the empty string (e.g. a path ending in a separator), so
        # test truthiness instead of identity with None.
        if not model_name:
            print(f"The model path {model_path} is invalid")
            return
        ort_test_dir_utils.create_test_dir(model_path, "./", test_utils.TEST_ORT_DIR)
        ort_test_dir_utils.run_test_dir(test_utils.TEST_ORT_DIR)
        # Replace any stale directory left over from a previous run.
        if os.path.exists(model_name) and os.path.isdir(model_name):
            rmtree(model_name)
        os.rename(test_utils.TEST_ORT_DIR, model_name)
        make_tarfile(tar_gz_path, model_name)
        rmtree(model_name)
    # otherwise use the existing "test_data_set_N" as test data
    else:
        test_dir_from_tar = test_utils.get_model_directory(model_path)
        ort_test_dir_utils.run_test_dir(test_dir_from_tar)
|
0 commit comments