Commit ed9cb4f

Merge pull request #324 from nbcsm/ci: fix CI failure, refine test code
2 parents: 15ac39c + 2dce373

18 files changed: +455, -264 lines

ci_build/azure_pipelines/pretrained_model_test-matrix.yml

Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@ jobs:
   parameters:
     platforms: ['linux', 'windows', 'mac']
     tf_versions: ['1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
-    onnx_versions: ['1.3']
     onnx_opsets: ['8', '7']
     onnx_backends:
       onnxruntime: ['0.2.1']

ci_build/azure_pipelines/templates/setup.yml

Lines changed: 8 additions & 0 deletions

@@ -5,6 +5,14 @@ steps:
     set -ex
     pip install pytest pytest-cov pytest-runner graphviz requests pyyaml pillow pandas
     pip install $(CI_PIP_TF_NAME) $(CI_PIP_ONNX_NAME) $(CI_PIP_ONNX_BACKEND_NAME)
+
+    # TF 1.10 requires numpy <=1.14.5 and >=1.13.3, but onnxruntime 0.2.1 does not work with numpy <= 1.14.5
+    # Upgrade numpy only within constraints from other packages if any.
+    if [[ $CI_TF_VERSION == 1.10* ]] && [[ $CI_ONNX_BACKEND == "onnxruntime" ]] ;
+    then
+      pip install $(CI_PIP_ONNX_NAME) $(CI_PIP_ONNX_BACKEND_NAME) numpy --no-deps -U
+    fi
+
    python setup.py install
    pip freeze --all
  displayName: 'Setup Environment'
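The added gate only fires for the TF 1.10 / onnxruntime combination and then reinstalls onnx, the backend, and numpy with --no-deps so no other pins are disturbed. As a purely illustrative aside (not part of this commit), the same condition restated in Python with LooseVersion, which the new tests/common.py also uses for version checks:

# Illustrative sketch only, not part of the CI pipeline: the bash gate above in Python.
from distutils.version import LooseVersion

import numpy as np
import tensorflow as tf


def needs_numpy_reinstall(backend):
    # TF 1.10 pins numpy to <=1.14.5, which onnxruntime 0.2.1 cannot use,
    # so CI reinstalls onnx, the backend and numpy with `--no-deps -U` afterwards.
    return (tf.__version__.startswith("1.10")
            and backend == "onnxruntime"
            and LooseVersion(np.__version__) <= LooseVersion("1.14.5"))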

ci_build/azure_pipelines/unit_test-matrix.yml

Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@ jobs:
   parameters:
     platforms: ['linux', 'windows', 'mac']
     tf_versions: ['1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
-    onnx_versions: ['1.3']
     onnx_opsets: ['8', '7']
     onnx_backends:
       onnxruntime: ['0.2.1']

setup.cfg

Lines changed: 1 addition & 2 deletions

@@ -2,5 +2,4 @@
 test=pytest
 
 [tool:pytest]
-addopts=--cov=tf2onnx --ignore=tests/test_custom_rnncell.py --ignore=tests/test_const_fold.py --ignore=tests/test_loops.py
-#testpaths=tests/test_*.py
+addopts=--cov=tf2onnx

tests/backend_test_base.py

Lines changed: 17 additions & 50 deletions

@@ -8,42 +8,31 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-import argparse
 import logging
 import os
-import sys
-import tempfile
 import unittest
 
 import numpy as np
 import tensorflow as tf
 from tensorflow.python.ops import variables as variables_lib
+from common import get_test_config
 from tf2onnx import utils
-from tf2onnx.tfonnx import process_tf_graph, tf_optimize, DEFAULT_TARGET, POSSIBLE_TARGETS
+from tf2onnx.tfonnx import process_tf_graph, tf_optimize
 
 
 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
 
 class Tf2OnnxBackendTestBase(unittest.TestCase):
-    # static variables
-    TMPPATH = tempfile.mkdtemp()
-    BACKEND = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
-    OPSET = int(os.environ.get("TF2ONNX_TEST_OPSET", 7))
-    TARGET = os.environ.get("TF2ONNX_TEST_TARGET", "").split(",")
-    DEBUG = None
-
-    def debug_mode(self):
-        return type(self).DEBUG
-
     def setUp(self):
+        self.config = get_test_config()
         self.maxDiff = None
         tf.reset_default_graph()
         # reset name generation on every test
         utils.INTERNAL_NAME = 1
         np.random.seed(1)  # Make it reproducible.
 
         self.log = logging.getLogger("tf2onnx.unitest." + str(type(self)))
-        if self.debug_mode():
+        if self.config.is_debug_mode:
             self.log.setLevel(logging.DEBUG)
         else:
             # suppress log info of tensorflow so that result of test can be seen much easier
@@ -83,17 +72,17 @@ def run_onnxruntime(self, model_path, inputs, output_names):
     def _run_backend(self, g, outputs, input_dict):
         model_proto = g.make_model("test")
         model_path = self.save_onnx_model(model_proto, input_dict)
-        if type(self).BACKEND == "onnxmsrtnext":
+        if self.config.backend == "onnxmsrtnext":
             y = self.run_onnxmsrtnext(model_path, input_dict, outputs)
-        elif type(self).BACKEND == "onnxruntime":
+        elif self.config.backend == "onnxruntime":
             y = self.run_onnxruntime(model_path, input_dict, outputs)
-        elif type(self).BACKEND == "caffe2":
+        elif self.config.backend == "caffe2":
             y = self.run_onnxcaffe2(model_proto, input_dict)
         else:
             raise ValueError("unknown backend")
         return y
 
-    def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07,
+    def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=0.,
                       convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=False,
                       check_dtype=False, process_args=None, onnx_feed_dict=None):
         # optional - passed to process_tf_graph
@@ -104,7 +93,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
             onnx_feed_dict = feed_dict
 
         graph_def = None
-        save_dir = os.path.join(type(self).TMPPATH, self._testMethodName)
+        save_dir = os.path.join(self.config.temp_path, self._testMethodName)
 
         if convert_var_to_const:
             with tf.Session() as sess:
@@ -123,7 +112,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
                 output_dict.append(sess.graph.get_tensor_by_name(out_name))
             expected = sess.run(output_dict, feed_dict=feed_dict)
 
-            if self.debug_mode():
+            if self.config.is_debug_mode:
                 if not os.path.exists(save_dir):
                     os.makedirs(save_dir)
                 model_path = os.path.join(save_dir, self._testMethodName + "_original.pb")
@@ -134,7 +123,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
             graph_def = tf_optimize(input_names_with_port, output_names_with_port,
                                     sess.graph_def, constant_fold)
 
-            if self.debug_mode() and constant_fold:
+            if self.config.is_debug_mode and constant_fold:
                 model_path = os.path.join(save_dir, self._testMethodName + "_after_tf_optimize.pb")
                 with open(model_path, "wb") as f:
                     f.write(graph_def.SerializeToString())
@@ -144,45 +133,23 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
             tf.import_graph_def(graph_def, name='')
 
         with tf.Session() as sess:
-            g = process_tf_graph(sess.graph, opset=type(self).OPSET, output_names=output_names_with_port,
-                                 target=type(self).TARGET, **process_args)
+            g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=output_names_with_port,
+                                 target=self.config.target, **process_args)
             actual = self._run_backend(g, output_names_with_port, onnx_feed_dict)
 
         for expected_val, actual_val in zip(expected, actual):
             if check_value:
-                self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.)
+                self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
             if check_dtype:
                 self.assertEqual(expected_val.dtype, actual_val.dtype)
             if check_shape:
                 self.assertEqual(expected_val.shape, actual_val.shape)
 
     def save_onnx_model(self, model_proto, feed_dict, postfix=""):
-        save_path = os.path.join(type(self).TMPPATH, self._testMethodName)
+        save_path = os.path.join(self.config.temp_path, self._testMethodName)
         target_path = utils.save_onnx_model(save_path, self._testMethodName + postfix, feed_dict, model_proto,
-                                            include_test_data=self.debug_mode(), as_text=self.debug_mode())
+                                            include_test_data=self.config.is_debug_mode,
+                                            as_text=self.config.is_debug_mode)
 
         self.log.debug("create model file: %s", target_path)
         return target_path
-
-    @staticmethod
-    def trigger(ut_class):
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--backend', default=Tf2OnnxBackendTestBase.BACKEND,
-                            choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
-                            help="backend to test against")
-        parser.add_argument('--opset', type=int, default=Tf2OnnxBackendTestBase.OPSET, help="opset to test against")
-        parser.add_argument("--target", default=",".join(DEFAULT_TARGET), choices=POSSIBLE_TARGETS,
-                            help="target platform")
-        parser.add_argument("--debug", help="output debugging information", action="store_true")
-        parser.add_argument('unittest_args', nargs='*')
-
-        args = parser.parse_args()
-        print(args)
-        ut_class.BACKEND = args.backend
-        ut_class.OPSET = args.opset
-        ut_class.DEBUG = args.debug
-        ut_class.TARGET = args.target
-
-        # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
-        sys.argv[1:] = args.unittest_args
-        unittest.main()
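With the argparse-based trigger() removed, a test module no longer parses its own command line; the base class now pulls the shared configuration from tests/common.py in setUp(). A rough sketch of the new pattern, assuming a hypothetical test class and graph (only Tf2OnnxBackendTestBase, run_test_case and unittest_main come from this commit):

# Hypothetical test module, for illustration only (not part of this commit).
import numpy as np
import tensorflow as tf

from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main


class SampleTests(Tf2OnnxBackendTestBase):
    def test_relu(self):
        x_val = np.random.random_sample([2, 3]).astype(np.float32)
        x = tf.placeholder(tf.float32, shape=[2, 3], name="input")
        _ = tf.nn.relu(x, name="output")
        # opset, target, backend and temp_path come from self.config,
        # which setUp() fetched via get_test_config().
        self.run_test_case({"input:0": x_val}, ["input:0"], ["output:0"],
                           rtol=1e-05, atol=1e-06)


if __name__ == "__main__":
    unittest_main()  # replaces the removed Tf2OnnxBackendTestBase.trigger(...)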

tests/common.py

Lines changed: 206 additions & 0 deletions

@@ -0,0 +1,206 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+""" test common utilities."""
+
+import argparse
+import os
+import sys
+import tempfile
+import unittest
+
+from distutils.version import LooseVersion
+from tf2onnx.tfonnx import DEFAULT_TARGET, POSSIBLE_TARGETS
+
+__all__ = ["TestConfig", "get_test_config", "unittest_main",
+           "check_tf_min_version", "check_opset_min_version", "check_target", "skip_onnxruntime_backend",
+           "skip_caffe2_backend", "check_onnxruntime_incompatibility"]
+
+
+# pylint: disable=missing-docstring
+
+class TestConfig(object):
+    def __init__(self):
+        self.platform = sys.platform
+        self.tf_version = self._get_tf_version()
+        self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", 7))
+        self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(DEFAULT_TARGET)).split(',')
+        self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
+        self.backend_version = self._get_backend_version()
+        self.is_debug_mode = False
+        self.temp_path = tempfile.mkdtemp()
+
+    @property
+    def is_mac(self):
+        return self.platform == "darwin"
+
+    @property
+    def is_onnxruntime_backend(self):
+        return self.backend == "onnxruntime"
+
+    @property
+    def is_caffe2_backend(self):
+        return self.backend == "caffe2"
+
+    def _get_tf_version(self):
+        import tensorflow as tf
+        return LooseVersion(tf.__version__)
+
+    def _get_backend_version(self):
+        version = None
+        if self.backend == "onnxruntime":
+            import onnxruntime as ort
+            version = ort.__version__
+        elif self.backend == "caffe2":
+            # TODO: get caffe2 version
+            pass
+
+        if version:
+            version = LooseVersion(version)
+        return version
+
+    def __str__(self):
+        return "\n\t".join(["TestConfig:",
+                            "platform={}".format(self.platform),
+                            "tf_version={}".format(self.tf_version),
+                            "opset={}".format(self.opset),
+                            "target={}".format(self.target),
+                            "backend={}".format(self.backend),
+                            "backend_version={}".format(self.backend_version),
+                            "is_debug_mode={}".format(self.is_debug_mode)])
+
+    @staticmethod
+    def load():
+        config = TestConfig()
+        # if not launched by pytest, parse console arguments to override config
+        if "pytest" not in sys.argv[0]:
+            parser = argparse.ArgumentParser()
+            parser.add_argument('--backend', default=config.backend,
+                                choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
+                                help="backend to test against")
+            parser.add_argument('--opset', type=int, default=config.opset, help="opset to test against")
+            parser.add_argument("--target", default=",".join(config.target), choices=POSSIBLE_TARGETS,
+                                help="target platform")
+            parser.add_argument("--debug", help="output debugging information", action="store_true")
+            parser.add_argument('unittest_args', nargs='*')
+
+            args = parser.parse_args()
+            config.backend = args.backend
+            config.opset = args.opset
+            config.target = args.target.split(',')
+            config.is_debug_mode = args.debug
+
+            # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
+            sys.argv[1:] = args.unittest_args
+
+        return config
+
+
+# need to load config BEFORE main is executed when launched from script
+# otherwise, it will be too late for test filters to take effect
+_config = TestConfig.load()
+
+
+def get_test_config():
+    global _config
+    return _config
+
+
+def unittest_main():
+    print(get_test_config())
+    unittest.main()
+
+
+def _append_message(reason, message):
+    if message:
+        reason = reason + ": " + message
+    return reason
+
+
+def check_tf_min_version(min_required_version, message=""):
+    """ Skip if tf_version < min_required_version """
+    config = get_test_config()
+    reason = _append_message("conversion requires tf >= {}".format(min_required_version), message)
+    return unittest.skipIf(config.tf_version < LooseVersion(min_required_version), reason)
+
+
+def skip_tf_versions(excluded_versions, message=""):
+    """ Skip if tf_version SEMANTICALLY matches any of excluded_versions. """
+    config = get_test_config()
+    condition = False
+    reason = _append_message("conversion excludes tf {}".format(excluded_versions), message)
+
+    current_tokens = str(config.tf_version).split('.')
+    for excluded_version in excluded_versions:
+        exclude_tokens = excluded_version.split('.')
+        # assume len(exclude_tokens) <= len(current_tokens)
+        for i, exclude in enumerate(exclude_tokens):
+            if not current_tokens[i] == exclude:
+                break
+            condition = True
+
+    return unittest.skipIf(condition, reason)
+
+
+def check_opset_min_version(min_required_version, message=""):
+    """ Skip if opset < min_required_version """
+    config = get_test_config()
+    reason = _append_message("conversion requires opset >= {}".format(min_required_version), message)
+    return unittest.skipIf(config.opset < min_required_version, reason)
+
+
+def check_target(required_target, message=""):
+    """ Skip if required_target is NOT specified """
+    config = get_test_config()
+    reason = _append_message("conversion requires target {} specified".format(required_target), message)
+    return unittest.skipIf(required_target not in config.target, reason)
+
+
+def skip_onnxruntime_backend(message=""):
+    """ Skip if backend is onnxruntime """
+    config = get_test_config()
+    reason = _append_message("not supported by onnxruntime", message)
+    return unittest.skipIf(config.is_onnxruntime_backend, reason)
+
+
+def skip_caffe2_backend(message=""):
+    """ Skip if backend is caffe2 """
+    config = get_test_config()
+    reason = _append_message("not supported by caffe2", message)
+    return unittest.skipIf(config.is_caffe2_backend, reason)
+
+
+def check_onnxruntime_incompatibility(op):
+    """ Skip if backend is onnxruntime AND op is NOT supported in current opset """
+    config = get_test_config()
+
+    if not config.is_onnxruntime_backend:
+        return unittest.skipIf(False, None)
+
+    support_since = {
+        "Abs": 6,  # Abs-1
+        "Add": 7,  # Add-1, Add-6
+        "AveragePool": 7,  # AveragePool-1
+        "Div": 7,  # Div-1, Div-6
+        "Elu": 6,  # Elu-1
+        "Exp": 6,  # Exp-1
+        "Greater": 7,  # Greater-1
+        "Less": 7,  # Less-1
+        "Log": 6,  # Log-1
+        "Max": 6,  # Max-1
+        "Min": 6,  # Min-1
+        "Mul": 7,  # Mul-1, Mul-6
+        "Neg": 6,  # Neg-1
+        "Pow": 7,  # Pow-1
+        "Reciprocal": 6,  # Reciprocal-1
+        "Relu": 6,  # Relu-1
+        "Sqrt": 6,  # Sqrt-1
+        "Sub": 7,  # Sub-1, Sub-6
+        "Tanh": 6,  # Tanh-1
+    }
+
+    if op not in support_since or config.opset >= support_since[op]:
+        return unittest.skipIf(False, None)
+
+    reason = "{} is not supported by onnxruntime before opset {}".format(op, support_since[op])
+    return unittest.skipIf(True, reason)
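As a usage note, these helpers are written to be stacked as decorators on individual test methods, with the skip decision made against the shared TestConfig loaded above. A small hedged example (the test class and methods are made up for illustration):

# Hypothetical examples of the skip/check decorators (not part of this commit).
from backend_test_base import Tf2OnnxBackendTestBase
from common import (unittest_main, check_opset_min_version,
                    skip_caffe2_backend, check_onnxruntime_incompatibility)


class DecoratedTests(Tf2OnnxBackendTestBase):
    @check_onnxruntime_incompatibility("Div")  # Div needs opset >= 7 on onnxruntime
    def test_div(self):
        pass  # ... build the graph and call self.run_test_case(...)

    @check_opset_min_version(8, "example of an opset-gated test")
    @skip_caffe2_backend("example of a backend-gated test")
    def test_new_op(self):
        pass  # ... only runs when opset >= 8 and the backend is not caffe2


if __name__ == "__main__":
    unittest_main()  # prints the active TestConfig, then calls unittest.main()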
