
Commit bf28aa8

jcwchen and garymm authored
Use avx512f instead of VNNI and refactor code in CIs (#533)
* Use avx512f instead of vnni and refactor code
  Signed-off-by: jcwchen <[email protected]>

* Update workflow_scripts/check_model.py
  Signed-off-by: jcwchen <[email protected]>
  Co-authored-by: Gary Miguel <[email protected]>

* fix flake8
  Signed-off-by: jcwchen <[email protected]>

* typo
  Signed-off-by: jcwchen <[email protected]>

Signed-off-by: jcwchen <[email protected]>
Co-authored-by: Gary Miguel <[email protected]>
1 parent 801bd0c commit bf28aa8

File tree: 4 files changed, +24 / -20 lines

.github/workflows/linux_ci.yml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         python -m pip install onnx onnxruntime requests py-cpuinfo
+        # Print CPU info for debugging ONNX Runtime inference difference
         python -m cpuinfo

     - name: Test updated ONNX_HUB_MANIFEST.json

.github/workflows/windows_ci.yml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ jobs:
         # TODO: now ONNX only supports Protobuf <= 3.20.1
         python -m pip install protobuf==3.20.1
         python -m pip install onnx onnxruntime requests py-cpuinfo
+        # Print CPU info for debugging ONNX Runtime inference difference
         python -m cpuinfo

     - name: Test new models by onnx
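Both workflow changes add the same diagnostic step: printing the CPU feature flags so that ONNX Runtime inference differences across runners can be traced to missing ISA support. The same information can be inspected locally with a minimal sketch like the following (assumes py-cpuinfo is installed; "brand_raw" is the key used by recent py-cpuinfo releases):

# Print the CPU info that py-cpuinfo reports; the CI checks this flag set
# for "avx512f" to decide whether the quantized-model ORT tests can run.
from cpuinfo import get_cpu_info

info = get_cpu_info()
print(info.get("brand_raw", "unknown CPU"))  # human-readable CPU name
print(sorted(info["flags"]))                 # e.g. ["avx2", "avx512f", ...]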

workflow_scripts/check_model.py

Lines changed: 18 additions & 16 deletions
@@ -10,53 +10,55 @@
 import test_utils


-def has_vnni_support():
-    return 'avx512' in str(get_cpu_info()['flags'])
-
-
-def skip_quant_models_if_missing_vnni(model_name):
-    return ('-int8' in model_name or '-qdq' in model_name) and not has_vnni_support()
+def has_avx512f_support():
+    return "avx512f" in set(get_cpu_info()["flags"])


 def run_onnx_checker(model_path):
     model = onnx.load(model_path)
     onnx.checker.check_model(model)


+def ort_skip_reason(model_path):
+    if ("-int8" in model_path or "-qdq" in model_path) and not has_avx512f_support():
+        return f"Skip ORT test for {model_path} because this machine lacks avx512f support and the output.pb was produced with avx512f support."
+    model = onnx.load(model_path)
+    if model.opset_import[0].version < 7:
+        return f"Skip ORT test for {model_path} because ORT only supports opset version >= 7"
+    return None
+
+
 def make_tarfile(output_filename, source_dir):
     with tarfile.open(output_filename, "w:gz", format=tarfile.GNU_FORMAT) as tar:
         tar.add(source_dir, arcname=os.path.basename(source_dir))


 def run_backend_ort(model_path, test_data_set=None, tar_gz_path=None):
-    if skip_quant_models_if_missing_vnni(model_path):
-        print(f'Skip ORT test for {model_path} because this machine lacks of VNNI support and the output.pb was produced with VNNI support.')
-        return
-    model = onnx.load(model_path)
-    if model.opset_import[0].version < 7:
-        print('Skip ORT test since it only *guarantees* support for models stamped with opset version 7')
+    skip_reason = ort_skip_reason(model_path)
+    if skip_reason:
+        print(skip_reason)
         return
-    # if 'test_data_set_N' doesn't exist, create test_dir
+    # if "test_data_set_N" doesn't exist, create test_dir
     if not test_data_set:
         # Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
         # other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
         # based on the build flags) when instantiating InferenceSession.
         # For example, if NVIDIA GPU is available and ORT Python package is built with CUDA, then call API as following:
-        # onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
+        # onnxruntime.InferenceSession(path/to/model, providers=["CUDAExecutionProvider"])
         onnxruntime.InferenceSession(model_path)
         # Get model name without .onnx
         model_name = os.path.basename(os.path.splitext(model_path)[0])
         if model_name is None:
             print(f"The model path {model_path} is invalid")
             return
-        ort_test_dir_utils.create_test_dir(model_path, './', test_utils.TEST_ORT_DIR)
+        ort_test_dir_utils.create_test_dir(model_path, "./", test_utils.TEST_ORT_DIR)
         ort_test_dir_utils.run_test_dir(test_utils.TEST_ORT_DIR)
         if os.path.exists(model_name) and os.path.isdir(model_name):
             rmtree(model_name)
         os.rename(test_utils.TEST_ORT_DIR, model_name)
         make_tarfile(tar_gz_path, model_name)
         rmtree(model_name)
-    # otherwise use the existing 'test_data_set_N' as test data
+    # otherwise use the existing "test_data_set_N" as test data
     else:
         test_dir_from_tar = test_utils.get_model_directory(model_path)
         ort_test_dir_utils.run_test_dir(test_dir_from_tar)
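Taken together, the refactor moves the skip decision into a single helper and keys it on the specific avx512f flag rather than a broad 'avx512' substring match. A minimal sketch of how the new helpers combine (assumes py-cpuinfo is installed; "model-int8.onnx" is a hypothetical path used only for illustration):

from cpuinfo import get_cpu_info

def has_avx512f_support():
    # py-cpuinfo returns CPU feature flags as a list of lowercase strings.
    return "avx512f" in set(get_cpu_info()["flags"])

# Quantized models (-int8/-qdq) only run against ORT on avx512f machines,
# because their reference output.pb files were produced with avx512f support.
model_path = "model-int8.onnx"  # hypothetical path, for illustration only
if ("-int8" in model_path or "-qdq" in model_path) and not has_avx512f_support():
    print(f"Skip ORT test for {model_path}: this machine lacks avx512f.")
else:
    print(f"ORT test for {model_path} would run.")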

workflow_scripts/generate_onnx_hub_manifest.py

Lines changed: 4 additions & 4 deletions
@@ -100,7 +100,7 @@ def prep_name(col):
 def get_file_info(row, field, target_models=None):
     source_dir = split(row["source_file"])[0]
     model_file = row[field].contents[0].attrs["href"]
-    ## So that model relative path is consistent across OS
+    # So that model relative path is consistent across OS
     rel_path = "/".join(join(source_dir, model_file).split(os.sep))
     if target_models is not None and rel_path not in target_models:
         return None
@@ -261,7 +261,7 @@ def get_model_ports(source_file, metadata, model_name):
         for k, v in get_file_info(row, "model_with_data_path").items():
             metadata[k] = v
     except (AttributeError, FileNotFoundError) as e:
-        print("no model_with_data in file {}".format(row["source_file"]))
+        print("no model_with_data in file {}: {}".format(row["source_file"], e))

     try:
         opset = int(row["opset_version"].contents[0])
@@ -291,7 +291,7 @@ def get_model_ports(source_file, metadata, model_name):

     else:
         print("Missing model in {}".format(row["source_file"]))
-    output.sort(key=lambda x:x["model_path"])
-    with open( "ONNX_HUB_MANIFEST.json", "w+") as f:
+    output.sort(key=lambda x: x["model_path"])
+    with open("ONNX_HUB_MANIFEST.json", "w+") as f:
         print("Found {} models".format(len(output)))
         json.dump(output, f, indent=4)
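For reference, a minimal sketch of the manifest write that this last hunk cleans up (the entries here are hypothetical; the real script builds output from parsed model README tables):

import json

output = [
    {"model_path": "vision/classification/b.onnx"},       # hypothetical entry
    {"model_path": "text/machine_comprehension/a.onnx"},  # hypothetical entry
]
output.sort(key=lambda x: x["model_path"])  # sort for a reproducible manifest
with open("ONNX_HUB_MANIFEST.json", "w+") as f:
    print("Found {} models".format(len(output)))
    json.dump(output, f, indent=4)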
