Skip to content

Commit 2f75c13

Browse files
committed
merge master
2 parents 1365fb1 + 5e98390 commit 2f75c13

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+754
-419
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ optional arguments:
161161
--tests TESTS tests to run
162162
--backend BACKEND backend to use
163163
--config yaml config file
164-
--verbose verbose output
164+
--verbose verbose output, option is additive
165165
--opset OPSET target opset to use
166166
--perf csv-file capture performance numbers for tensorflow and onnx runtime
167167
--debug dump generated graph with shape info
@@ -176,6 +176,13 @@ You call it for example with:
176176
python tests/run_pretrained_models.py --backend onnxruntime --config tests/run_pretrained_models.yaml --perf perf.csv
177177
```
178178

179+
### <a name="save_pretrained_model"></a>Tool to save pre-trained model
180+
181+
We provide a [utility](tools/save_pretrained_model.py) to save a pre-trained model along with its config.
182+
Put `save_pretrained_model(sess, outputs, feed_inputs, save_dir, model_name)` in your last testing epoch and the pre-trained model and config will be saved under `save_dir/to_onnx`.
183+
Please refer to the example in [tools/save_pretrained_model.py](tools/save_pretrained_model.py) for more information.
184+
Note the minimum required TensorFlow version is r1.6.
185+
179186
# Using the Python API
180187
## TensorFlow to ONNX conversion
181188
In some cases it will be useful to convert the models from TensorFlow to ONNX from a python script. You can use the following API:
@@ -192,7 +199,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
192199
Args:
193200
tf_graph: tensorflow graph
194201
continue_on_error: if an op can't be processed (aka there is no mapping), continue
195-
verbose: print summary stats
202+
verbose: print summary stats (deprecated)
196203
target: list of workarounds applied to help certain platforms
197204
opset: the opset to be used (int, default is latest)
198205
custom_op_handlers: dictionary of custom ops handlers

tests/backend_test_base.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,11 @@
2626
class Tf2OnnxBackendTestBase(unittest.TestCase):
2727
def setUp(self):
2828
self.config = get_test_config()
29-
self.maxDiff = None
3029
tf.reset_default_graph()
3130
# reset name generation on every test
3231
utils.INTERNAL_NAME = 1
3332
np.random.seed(1) # Make it reproducible.
34-
35-
self.log = logging.getLogger("tf2onnx.unitest." + str(type(self)))
36-
if self.config.is_debug_mode:
37-
self.log.setLevel(logging.DEBUG)
38-
else:
39-
# suppress log info of tensorflow so that result of test can be seen much easier
40-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
41-
tf.logging.set_verbosity(tf.logging.WARN)
42-
self.log.setLevel(logging.INFO)
33+
self.logger = logging.getLogger(self.__class__.__name__)
4334

4435
def tearDown(self):
4536
if not self.config.is_debug_mode:
@@ -125,15 +116,15 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
125116
os.makedirs(self.test_data_directory)
126117
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_original.pb")
127118
utils.save_protobuf(model_path, sess.graph_def)
128-
self.log.debug("created file %s", model_path)
119+
self.logger.debug("created file %s", model_path)
129120

130121
graph_def = tf_optimize(input_names_with_port, output_names_with_port,
131122
sess.graph_def, constant_fold)
132123

133124
if self.config.is_debug_mode:
134125
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
135126
utils.save_protobuf(model_path, graph_def)
136-
self.log.debug("created file %s", model_path)
127+
self.logger.debug("created file %s", model_path)
137128

138129
tf.reset_default_graph()
139130
tf.import_graph_def(graph_def, name='')
@@ -162,5 +153,5 @@ def save_onnx_model(self, model_proto, feed_dict, postfix=""):
162153
model_proto, include_test_data=self.config.is_debug_mode,
163154
as_text=self.config.is_debug_mode)
164155

165-
self.log.debug("create model file: %s", target_path)
156+
self.logger.debug("create model file: %s", target_path)
166157
return target_path

tests/common.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
from distutils.version import LooseVersion
1313
from parameterized import parameterized
14-
from tf2onnx import constants, utils
14+
from tf2onnx import constants, logging, utils
1515

1616
__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend",
17-
"check_tf_min_version", "skip_tf_versions", "check_onnxruntime_min_version",
17+
"check_tf_min_version", "check_tf_max_version", "skip_tf_versions", "check_onnxruntime_min_version",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1919
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
2020
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]
@@ -30,7 +30,7 @@ def __init__(self):
3030
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
3131
self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
3232
self.backend_version = self._get_backend_version()
33-
self.is_debug_mode = False
33+
self.log_level = logging.WARNING
3434
self.temp_dir = utils.get_temp_directory()
3535

3636
@property
@@ -45,6 +45,10 @@ def is_onnxruntime_backend(self):
4545
def is_caffe2_backend(self):
4646
return self.backend == "caffe2"
4747

48+
@property
49+
def is_debug_mode(self):
50+
return utils.is_debug_mode()
51+
4852
def _get_tf_version(self):
4953
import tensorflow as tf
5054
return LooseVersion(tf.__version__)
@@ -85,15 +89,19 @@ def load():
8589
parser.add_argument("--opset", type=int, default=config.opset, help="opset to test against")
8690
parser.add_argument("--target", default=",".join(config.target), choices=constants.POSSIBLE_TARGETS,
8791
help="target platform")
92+
parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
8893
parser.add_argument("--debug", help="output debugging information", action="store_true")
8994
parser.add_argument("--temp_dir", help="temp dir")
9095
parser.add_argument("unittest_args", nargs='*')
9196

9297
args = parser.parse_args()
98+
if args.debug:
99+
utils.set_debug_mode(True)
100+
93101
config.backend = args.backend
94102
config.opset = args.opset
95103
config.target = args.target.split(',')
96-
config.is_debug_mode = args.debug
104+
config.log_level = logging.get_verbosity_level(args.verbose, config.log_level)
97105
if args.temp_dir:
98106
config.temp_dir = args.temp_dir
99107

@@ -114,7 +122,10 @@ def get_test_config():
114122

115123

116124
def unittest_main():
117-
print(get_test_config())
125+
config = get_test_config()
126+
logging.basicConfig(level=config.log_level)
127+
with logging.set_scope_level(logging.INFO) as logger:
128+
logger.info(config)
118129
unittest.main()
119130

120131

@@ -124,6 +135,13 @@ def _append_message(reason, message):
124135
return reason
125136

126137

138+
def check_tf_max_version(max_accepted_version, message=""):
139+
""" Skip if tf_version > max_required_version """
140+
config = get_test_config()
141+
reason = _append_message("conversion requires tf <= {}".format(max_accepted_version), message)
142+
return unittest.skipIf(config.tf_version > LooseVersion(max_accepted_version), reason)
143+
144+
127145
def check_tf_min_version(min_required_version, message=""):
128146
""" Skip if tf_version < min_required_version """
129147
config = get_test_config()

tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
""" print pytest config."""
55

66
from common import get_test_config
7+
from tf2onnx import logging
78

89

910
def pytest_configure():
10-
print(get_test_config())
11+
config = get_test_config()
12+
logging.basicConfig(level=config.log_level)
13+
with logging.set_scope_level(logging.INFO) as logger:
14+
logger.info(config)

tests/run_pretrained_models.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import time
1616
import traceback
1717
import zipfile
18-
import logging
1918

2019
import PIL.Image
2120
import numpy as np
@@ -28,15 +27,12 @@
2827
import yaml
2928

3029
import tf2onnx
31-
from tf2onnx import loader
32-
from tf2onnx import utils
33-
from tf2onnx import optimizer
30+
from tf2onnx import loader, logging, optimizer, utils
3431
from tf2onnx.tfonnx import process_tf_graph
3532

3633
# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda
3734

38-
logging.basicConfig(level=logging.INFO)
39-
log = logging.getLogger("tf2onnx")
35+
logger = logging.getLogger("run_pretrained")
4036

4137
TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained")
4238
PERFITER = 1000
@@ -157,7 +153,7 @@ def run_tensorflow(self, sess, inputs):
157153

158154
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None):
159155
"""Convert graph to tensorflow."""
160-
return process_tf_graph(tf_graph, continue_on_error=False, verbose=True, opset=opset,
156+
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
161157
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
162158
input_names=input_names, output_names=self.output_names)
163159

@@ -207,7 +203,7 @@ def create_onnx_file(name, model_proto, inputs, outdir):
207203
utils.save_protobuf(model_path, model_proto)
208204
print("\tcreated", model_path)
209205

210-
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, extra_opset=None,
206+
def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
211207
perf=None, fold_const=None):
212208
"""Run complete test against backend."""
213209
print(name)
@@ -222,18 +218,20 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
222218
dir_name = os.path.dirname(self.local)
223219
print("\tdownloaded", model_path)
224220

225-
inputs = list(self.input_names.keys())
221+
input_names = list(self.input_names.keys())
226222
outputs = self.output_names
227223
if self.model_type in ["checkpoint"]:
228-
graph_def, inputs, outputs = loader.from_checkpoint(model_path, inputs, outputs)
224+
graph_def, input_names, outputs = loader.from_checkpoint(model_path, input_names, outputs)
229225
elif self.model_type in ["saved_model"]:
230-
graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
226+
graph_def, input_names, outputs = loader.from_saved_model(model_path, input_names, outputs)
231227
else:
232-
graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)
228+
graph_def, input_names, outputs = loader.from_graphdef(model_path, input_names, outputs)
233229

234230
# create the input data
235231
inputs = {}
236232
for k, v in self.input_names.items():
233+
if k not in input_names:
234+
continue
237235
if isinstance(v, six.text_type) and v.startswith("np."):
238236
inputs[k] = eval(v) # pylint: disable=eval-used
239237
else:
@@ -243,7 +241,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
243241
inputs[k] = v
244242

245243
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
246-
if debug:
244+
if utils.is_debug_mode():
247245
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
248246
shape_override = {}
249247
g = tf.import_graph_def(graph_def, name='')
@@ -255,7 +253,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
255253
dtype = tf.as_dtype(t.dtype).name
256254
v = inputs[k]
257255
if dtype != v.dtype:
258-
log.warning("input dtype doesn't match tensorflow's")
256+
logger.warning("input dtype doesn't match tensorflow's")
259257
inputs[k] = np.array(v, dtype=dtype)
260258
if self.force_input_shape:
261259
for k, v in inputs.items():
@@ -273,13 +271,13 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
273271
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
274272
shape_override=shape_override, input_names=inputs.keys())
275273
model_proto = onnx_graph.make_model("converted from tf2onnx")
276-
new_model_proto = optimizer.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
274+
new_model_proto = optimizer.optimize_graph(onnx_graph).make_model("optimized")
277275
if new_model_proto:
278276
model_proto = new_model_proto
279277
else:
280278
print("\tNON-CRITICAL, optimizers are not applied successfully")
281279
print("\tto_onnx", "OK")
282-
if debug:
280+
if utils.is_debug_mode():
283281
onnx_graph.dump_graph()
284282
if onnx_file:
285283
self.create_onnx_file(name, model_proto, inputs, onnx_file)
@@ -312,10 +310,12 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
312310
print("\tResults: OK")
313311
return True
314312
except Exception as ex:
315-
print("\tResults: ", ex)
313+
tb = traceback.format_exc()
314+
print("\tResults", ex, tb)
316315

317316
except Exception as ex:
318-
print("\trun_onnx", "FAIL", ex)
317+
tb = traceback.format_exc()
318+
print("\trun_onnx", "FAIL", ex, tb)
319319

320320
return False
321321

@@ -329,11 +329,11 @@ def get_args():
329329
parser.add_argument("--target", default="", help="target platform")
330330
parser.add_argument("--backend", default="onnxruntime",
331331
choices=["caffe2", "onnxmsrtnext", "onnxruntime"], help="backend to use")
332-
parser.add_argument("--verbose", help="verbose output", action="store_true")
333332
parser.add_argument("--opset", type=int, default=None, help="opset to use")
334333
parser.add_argument("--extra_opset", default=None,
335334
help="extra opset with format like domain:version, e.g. com.microsoft:1")
336-
parser.add_argument("--debug", help="debug vlog", action="store_true")
335+
parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
336+
parser.add_argument("--debug", help="debug mode", action="store_true")
337337
parser.add_argument("--list", help="list tests", action="store_true")
338338
parser.add_argument("--onnx-file", help="create onnx file in directory")
339339
parser.add_argument("--perf", help="capture performance numbers")
@@ -370,11 +370,11 @@ def tests_from_yaml(fname):
370370

371371

372372
def main():
373-
# suppress log info of tensorflow so that result of test can be seen much easier
374-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
375-
tf.logging.set_verbosity(tf.logging.WARN)
376-
377373
args = get_args()
374+
logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
375+
if args.debug:
376+
utils.set_debug_mode(True)
377+
378378
Test.cache_dir = args.cache
379379
Test.target = args.target
380380
tests = tests_from_yaml(args.config)
@@ -394,14 +394,15 @@ def main():
394394
continue
395395
count += 1
396396
try:
397-
ret = t.run_test(test, backend=args.backend, debug=args.debug, onnx_file=args.onnx_file,
397+
ret = t.run_test(test, backend=args.backend, onnx_file=args.onnx_file,
398398
opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
399399
fold_const=args.fold_const)
400400
except Exception as ex:
401401
ret = None
402-
print(ex)
402+
tb = traceback.format_exc()
403+
print(ex, tb)
403404
finally:
404-
if not args.debug:
405+
if not utils.is_debug_mode():
405406
utils.delete_directory(TEMP_DIR)
406407
if not ret:
407408
failed += 1

tests/run_pretrained_models.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,30 @@ saved_model_with_redundant_inputs:
3232
model_type: saved_model
3333
input_get: get_ramp
3434
inputs:
35+
"X:0": [1, 10]
3536
"Placeholder:0": [1, 10]
3637
outputs:
3738
- Add:0
3839

40+
graphdef_with_redundant_inputs:
41+
model: tests/models/regression/graphdef/frozen.pb
42+
input_get: get_ramp
43+
inputs:
44+
"X:0": [1, 10]
45+
"Placeholder:0": [1, 10]
46+
outputs:
47+
- Add:0
48+
49+
checkpoint_with_redundant_inputs:
50+
model: tests/models/regression/checkpoint/model.meta
51+
model_type: checkpoint
52+
input_get: get_ramp
53+
inputs:
54+
"X:0": [1]
55+
"Placeholder:0": [1, 10]
56+
outputs:
57+
- pred:0
58+
3959
benchtf-fc:
4060
model: tests/models/fc-layers/frozen.pb
4161
input_get: get_ramp

0 commit comments

Comments
 (0)