Skip to content

Commit dc456fc

Browse files
committed
unify temp test data saving logic
1 parent d6f1411 commit dc456fc

File tree

6 files changed

+73
-50
lines changed

6 files changed

+73
-50
lines changed

tests/backend_test_base.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def setUp(self):
4040
tf.logging.set_verbosity(tf.logging.WARN)
4141
self.log.setLevel(logging.INFO)
4242

43+
def tearDown(self):
44+
if not self.config.is_debug_mode:
45+
utils.delete_directory(self.test_data_directory)
46+
47+
@property
48+
def test_data_directory(self):
49+
return os.path.join(self.config.temp_dir, self._testMethodName)
50+
4351
@staticmethod
4452
def assertAllClose(expected, actual, **kwargs):
4553
np.testing.assert_allclose(expected, actual, **kwargs)
@@ -72,6 +80,7 @@ def run_onnxruntime(self, model_path, inputs, output_names):
7280
def _run_backend(self, g, outputs, input_dict):
7381
model_proto = g.make_model("test")
7482
model_path = self.save_onnx_model(model_proto, input_dict)
83+
7584
if self.config.backend == "onnxmsrtnext":
7685
y = self.run_onnxmsrtnext(model_path, input_dict, outputs)
7786
elif self.config.backend == "onnxruntime":
@@ -93,8 +102,6 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
93102
onnx_feed_dict = feed_dict
94103

95104
graph_def = None
96-
save_dir = os.path.join(self.config.temp_path, self._testMethodName)
97-
98105
if convert_var_to_const:
99106
with tf.Session() as sess:
100107
variables_lib.global_variables_initializer().run()
@@ -113,20 +120,18 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
113120
expected = sess.run(output_dict, feed_dict=feed_dict)
114121

115122
if self.config.is_debug_mode:
116-
if not os.path.exists(save_dir):
117-
os.makedirs(save_dir)
118-
model_path = os.path.join(save_dir, self._testMethodName + "_original.pb")
119-
with open(model_path, "wb") as f:
120-
f.write(sess.graph_def.SerializeToString())
123+
if not os.path.exists(self.test_data_directory):
124+
os.makedirs(self.test_data_directory)
125+
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_original.pb")
126+
utils.save_protobuf(model_path, sess.graph_def)
121127
self.log.debug("created file %s", model_path)
122128

123129
graph_def = tf_optimize(input_names_with_port, output_names_with_port,
124130
sess.graph_def, constant_fold)
125131

126132
if self.config.is_debug_mode and constant_fold:
127-
model_path = os.path.join(save_dir, self._testMethodName + "_after_tf_optimize.pb")
128-
with open(model_path, "wb") as f:
129-
f.write(graph_def.SerializeToString())
133+
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
134+
utils.save_protobuf(model_path, graph_def)
130135
self.log.debug("created file %s", model_path)
131136

132137
tf.reset_default_graph()
@@ -146,9 +151,8 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
146151
self.assertEqual(expected_val.shape, actual_val.shape)
147152

148153
def save_onnx_model(self, model_proto, feed_dict, postfix=""):
149-
save_path = os.path.join(self.config.temp_path, self._testMethodName)
150-
target_path = utils.save_onnx_model(save_path, self._testMethodName + postfix, feed_dict, model_proto,
151-
include_test_data=self.config.is_debug_mode,
154+
target_path = utils.save_onnx_model(self.test_data_directory, self._testMethodName + postfix, feed_dict,
155+
model_proto, include_test_data=self.config.is_debug_mode,
152156
as_text=self.config.is_debug_mode)
153157

154158
self.log.debug("create model file: %s", target_path)

tests/common.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
import argparse
77
import os
88
import sys
9-
import tempfile
109
import unittest
1110

1211
from distutils.version import LooseVersion
12+
from tf2onnx import utils
1313
from tf2onnx.tfonnx import DEFAULT_TARGET, POSSIBLE_TARGETS
1414

1515
__all__ = ["TestConfig", "get_test_config", "unittest_main",
16-
"check_tf_min_version", "check_opset_min_version", "check_target", "skip_onnxruntime_backend",
17-
"skip_caffe2_backend", "check_onnxruntime_incompatibility"]
16+
"check_tf_min_version", "skip_tf_versions",
17+
"check_opset_min_version", "check_target", "skip_onnxruntime_backend", "skip_caffe2_backend",
18+
"check_onnxruntime_incompatibility"]
1819

1920

2021
# pylint: disable=missing-docstring
@@ -28,7 +29,7 @@ def __init__(self):
2829
self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
2930
self.backend_version = self._get_backend_version()
3031
self.is_debug_mode = False
31-
self.temp_path = tempfile.mkdtemp()
32+
self.temp_dir = utils.get_temp_directory()
3233

3334
@property
3435
def is_mac(self):
@@ -67,28 +68,32 @@ def __str__(self):
6768
"target={}".format(self.target),
6869
"backend={}".format(self.backend),
6970
"backend_version={}".format(self.backend_version),
70-
"is_debug_mode={}".format(self.is_debug_mode)])
71+
"is_debug_mode={}".format(self.is_debug_mode),
72+
"temp_dir={}".format(self.temp_dir)])
7173

7274
@staticmethod
7375
def load():
7476
config = TestConfig()
7577
# if not launched by pytest, parse console arguments to override config
7678
if "pytest" not in sys.argv[0]:
7779
parser = argparse.ArgumentParser()
78-
parser.add_argument('--backend', default=config.backend,
80+
parser.add_argument("--backend", default=config.backend,
7981
choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
8082
help="backend to test against")
81-
parser.add_argument('--opset', type=int, default=config.opset, help="opset to test against")
83+
parser.add_argument("--opset", type=int, default=config.opset, help="opset to test against")
8284
parser.add_argument("--target", default=",".join(config.target), choices=POSSIBLE_TARGETS,
8385
help="target platform")
8486
parser.add_argument("--debug", help="output debugging information", action="store_true")
85-
parser.add_argument('unittest_args', nargs='*')
87+
parser.add_argument("--temp_dir", help="temp dir")
88+
parser.add_argument("unittest_args", nargs='*')
8689

8790
args = parser.parse_args()
8891
config.backend = args.backend
8992
config.opset = args.opset
9093
config.target = args.target.split(',')
9194
config.is_debug_mode = args.debug
95+
if args.temp_dir:
96+
config.temp_dir = args.temp_dir
9297

9398
# Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
9499
sys.argv[1:] = args.unittest_args

tests/run_pretrained_models.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010

1111
import argparse
1212
import os
13-
import shutil
1413
import sys
1514
import tarfile
16-
import tempfile
1715
import time
1816
import traceback
1917
import zipfile
@@ -37,7 +35,7 @@
3735

3836
# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda
3937

40-
TMPPATH = tempfile.mkdtemp()
38+
TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained")
4139
PERFITER = 1000
4240

4341

@@ -192,7 +190,7 @@ def run_caffe2(self, name, model_proto, inputs):
192190
def run_onnxmsrtnext(self, name, model_proto, inputs):
193191
"""Run test against msrt-next backend."""
194192
import lotus
195-
model_path = utils.save_onnx_model(TMPPATH, name, inputs, model_proto)
193+
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto)
196194
m = lotus.InferenceSession(model_path)
197195
results = m.run(self.output_names, inputs)
198196
if self.perf:
@@ -205,8 +203,8 @@ def run_onnxmsrtnext(self, name, model_proto, inputs):
205203
def run_onnxruntime(self, name, model_proto, inputs):
206204
"""Run test against msrt-next backend."""
207205
import onnxruntime as rt
208-
model_path = utils.save_onnx_model(TMPPATH, name, inputs, model_proto, include_test_data=True)
209-
utils.save_onnx_model(TMPPATH, name, inputs, model_proto, include_test_data=False, as_text=True)
206+
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True)
207+
utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=False, as_text=True)
210208
print("\t\t" + model_path)
211209
m = rt.InferenceSession(model_path)
212210
results = m.run(self.output_names, inputs)
@@ -221,8 +219,7 @@ def run_onnxruntime(self, name, model_proto, inputs):
221219
def create_onnx_file(name, model_proto, inputs, outdir):
222220
os.makedirs(outdir, exist_ok=True)
223221
model_path = os.path.join(outdir, name + ".onnx")
224-
with open(model_path, "wb") as f:
225-
f.write(model_proto.SerializeToString())
222+
utils.save_protobuf(model_path, model_proto)
226223
print("\tcreated", model_path)
227224

228225
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, perf=None, fold_const=None):
@@ -440,8 +437,8 @@ def main():
440437
ret = None
441438
print(ex)
442439
finally:
443-
if os.path.exists(TMPPATH) and not args.debug:
444-
shutil.rmtree(TMPPATH)
440+
if not args.debug:
441+
utils.delete_directory(TEMP_DIR)
445442
if not ret:
446443
failed += 1
447444

tests/test_backend.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,8 @@ def test_relu(self):
722722
@check_onnxruntime_incompatibility("Mul")
723723
def test_leaky_relu(self):
724724
for alpha in [0.1, -0.1, 1.0, -1.0, 10.0, -10.0]:
725-
x_val = 1000*np.random.random_sample([1000, 100]).astype(np.float32)
726-
x = tf.placeholder(tf.float32, [None]*x_val.ndim, name=_TFINPUT)
725+
x_val = 1000 * np.random.random_sample([1000, 100]).astype(np.float32)
726+
x = tf.placeholder(tf.float32, [None] * x_val.ndim, name=_TFINPUT)
727727
x_ = tf.nn.leaky_relu(x, alpha)
728728
_ = tf.identity(x_, name=_TFOUTPUT)
729729
self._run_test_case([_OUTPUT], {_INPUT: x_val})
@@ -1199,7 +1199,7 @@ def test_strided_slice6(self):
11991199

12001200
@skip_caffe2_backend("multiple dims not supported")
12011201
def test_strided_slice7(self):
1202-
x_val = np.arange(5*6).astype("float32").reshape(5, 6)
1202+
x_val = np.arange(5 * 6).astype("float32").reshape(5, 6)
12031203

12041204
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
12051205
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], begin_mask=2)
@@ -1218,7 +1218,6 @@ def test_strided_slice7(self):
12181218
_ = tf.identity(x_, name=_TFOUTPUT)
12191219
self._run_test_case([_OUTPUT], {_INPUT: x_val})
12201220

1221-
12221221
@skip_caffe2_backend("fails with schema error")
12231222
@check_opset_min_version(7, "batchnorm")
12241223
def test_batchnorm(self):

tf2onnx/convert.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
from onnx import helper
1515
import tensorflow as tf
1616

17-
import tf2onnx.utils
17+
import tf2onnx
18+
from tf2onnx import utils
1819
from tf2onnx.graph import GraphUtil
1920
from tf2onnx.tfonnx import process_tf_graph, tf_optimize, DEFAULT_TARGET, POSSIBLE_TARGETS
2021

2122
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
2223

24+
2325
# pylint: disable=unused-argument
2426

2527

@@ -45,7 +47,7 @@ def get_args():
4547

4648
args.shape_override = None
4749
if args.inputs:
48-
args.inputs, args.shape_override = tf2onnx.utils.split_nodename_and_shape(args.inputs)
50+
args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
4951
if args.outputs:
5052
args.outputs = args.outputs.split(",")
5153
if args.inputs_as_nchw:
@@ -64,14 +66,14 @@ def default_custom_op_handler(ctx, node, name, args):
6466
def main():
6567
args = get_args()
6668

67-
opset = tf2onnx.utils.find_opset(args.opset)
69+
opset = utils.find_opset(args.opset)
6870
print("using tensorflow={}, onnx={}, opset={}, tfonnx={}/{}".format(
6971
tf.__version__, onnx.__version__, opset,
7072
tf2onnx.__version__, tf2onnx.version.git_version[:6]))
7173

7274
# override unknown dimensions from -1 to 1 (aka batchsize 1) since not every runtime does
7375
# support unknown dimensions.
74-
tf2onnx.utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim
76+
utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim
7577

7678
if args.custom_ops:
7779
# default custom ops for tensorflow-onnx are in the "tf" namespace
@@ -113,9 +115,8 @@ def main():
113115

114116
# write onnx graph
115117
if args.output:
116-
with open(args.output, "wb") as f:
117-
f.write(model_proto.SerializeToString())
118-
print("\nComplete successfully, the onnx model is generated at " + args.output)
118+
utils.save_protobuf(args.output, model_proto)
119+
print("\nComplete successfully, the onnx model is generated at " + args.output)
119120

120121

121122
if __name__ == "__main__":

tf2onnx/utils.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import os
1313
import re
14+
import shutil
15+
import tempfile
1416
import six
1517
import numpy as np
1618
import tensorflow as tf
@@ -95,6 +97,7 @@
9597
# Fake onnx op type which is used for Graph input.
9698
GRAPH_INPUT_TYPE = "NON_EXISTENT_ONNX_TYPE"
9799

100+
98101
def make_name(name):
99102
"""Make op name for inserted ops."""
100103
global INTERNAL_NAME
@@ -274,18 +277,13 @@ def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, incl
274277
t = numpy_helper.from_array(data)
275278
t.name = data_key
276279
data_full_path = os.path.join(data_path, "input_" + str(i) + ".pb")
277-
with open(data_full_path, 'wb') as f:
278-
f.write(t.SerializeToString())
280+
save_protobuf(data_full_path, t)
279281
i += 1
280282

281283
target_path = os.path.join(save_path, onnx_file_name + ".onnx")
282-
with open(target_path, "wb") as f:
283-
f.write(model_proto.SerializeToString())
284-
284+
save_protobuf(target_path, model_proto)
285285
if as_text:
286-
with open(target_path + ".pbtxt", "w") as f:
287-
f.write(text_format.MessageToString(model_proto))
288-
286+
save_protobuf(target_path + ".pbtxt", model_proto, as_text=True)
289287
return target_path
290288

291289

@@ -347,3 +345,22 @@ def tf_name_scope(name):
347345
def create_vague_shape_like(shape):
348346
make_sure(len(shape) >= 0, "rank should be >= 0")
349347
return [-1 for i in enumerate(shape)]
348+
349+
350+
def get_temp_directory():
351+
return os.environ.get("TF2ONNX_TEMP_DIRECTORY", tempfile.mkdtemp())
352+
353+
354+
def delete_directory(path):
355+
if os.path.exists(path):
356+
shutil.rmtree(path)
357+
358+
359+
def save_protobuf(path, message, as_text=False):
360+
os.makedirs(os.path.dirname(path), exist_ok=True)
361+
if as_text:
362+
with open(path, "w") as f:
363+
f.write(text_format.MessageToString(message))
364+
else:
365+
with open(path, "wb") as f:
366+
f.write(message.SerializeToString())

0 commit comments

Comments
 (0)