8
8
from __future__ import print_function
9
9
from __future__ import unicode_literals
10
10
11
- import argparse
12
11
import logging
13
12
import os
14
- import sys
15
- import tempfile
16
13
import unittest
17
14
18
15
import numpy as np
19
16
import tensorflow as tf
20
17
from tensorflow .python .ops import variables as variables_lib
18
+ from common import get_test_config
21
19
from tf2onnx import utils
22
- from tf2onnx .tfonnx import process_tf_graph , tf_optimize , DEFAULT_TARGET , POSSIBLE_TARGETS
20
+ from tf2onnx .tfonnx import process_tf_graph , tf_optimize
23
21
24
22
25
23
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
26
24
27
25
class Tf2OnnxBackendTestBase (unittest .TestCase ):
28
- # static variables
29
- TMPPATH = tempfile .mkdtemp ()
30
- BACKEND = os .environ .get ("TF2ONNX_TEST_BACKEND" , "onnxruntime" )
31
- OPSET = int (os .environ .get ("TF2ONNX_TEST_OPSET" , 7 ))
32
- TARGET = os .environ .get ("TF2ONNX_TEST_TARGET" , "" ).split ("," )
33
- DEBUG = None
34
-
35
- def debug_mode (self ):
36
- return type (self ).DEBUG
37
-
38
26
def setUp (self ):
27
+ self .config = get_test_config ()
39
28
self .maxDiff = None
40
29
tf .reset_default_graph ()
41
30
# reset name generation on every test
42
31
utils .INTERNAL_NAME = 1
43
32
np .random .seed (1 ) # Make it reproducible.
44
33
45
34
self .log = logging .getLogger ("tf2onnx.unitest." + str (type (self )))
46
- if self .debug_mode () :
35
+ if self .config . is_debug_mode :
47
36
self .log .setLevel (logging .DEBUG )
48
37
else :
49
38
# suppress log info of tensorflow so that result of test can be seen much easier
@@ -83,17 +72,17 @@ def run_onnxruntime(self, model_path, inputs, output_names):
83
72
def _run_backend (self , g , outputs , input_dict ):
84
73
model_proto = g .make_model ("test" )
85
74
model_path = self .save_onnx_model (model_proto , input_dict )
86
- if type ( self ). BACKEND == "onnxmsrtnext" :
75
+ if self . config . backend == "onnxmsrtnext" :
87
76
y = self .run_onnxmsrtnext (model_path , input_dict , outputs )
88
- elif type ( self ). BACKEND == "onnxruntime" :
77
+ elif self . config . backend == "onnxruntime" :
89
78
y = self .run_onnxruntime (model_path , input_dict , outputs )
90
- elif type ( self ). BACKEND == "caffe2" :
79
+ elif self . config . backend == "caffe2" :
91
80
y = self .run_onnxcaffe2 (model_proto , input_dict )
92
81
else :
93
82
raise ValueError ("unknown backend" )
94
83
return y
95
84
96
- def run_test_case (self , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-07 ,
85
+ def run_test_case (self , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-07 , atol = 0. ,
97
86
convert_var_to_const = True , constant_fold = True , check_value = True , check_shape = False ,
98
87
check_dtype = False , process_args = None , onnx_feed_dict = None ):
99
88
# optional - passed to process_tf_graph
@@ -104,7 +93,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
104
93
onnx_feed_dict = feed_dict
105
94
106
95
graph_def = None
107
- save_dir = os .path .join (type ( self ). TMPPATH , self ._testMethodName )
96
+ save_dir = os .path .join (self . config . temp_path , self ._testMethodName )
108
97
109
98
if convert_var_to_const :
110
99
with tf .Session () as sess :
@@ -123,7 +112,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
123
112
output_dict .append (sess .graph .get_tensor_by_name (out_name ))
124
113
expected = sess .run (output_dict , feed_dict = feed_dict )
125
114
126
- if self .debug_mode () :
115
+ if self .config . is_debug_mode :
127
116
if not os .path .exists (save_dir ):
128
117
os .makedirs (save_dir )
129
118
model_path = os .path .join (save_dir , self ._testMethodName + "_original.pb" )
@@ -134,7 +123,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
134
123
graph_def = tf_optimize (input_names_with_port , output_names_with_port ,
135
124
sess .graph_def , constant_fold )
136
125
137
- if self .debug_mode () and constant_fold :
126
+ if self .config . is_debug_mode and constant_fold :
138
127
model_path = os .path .join (save_dir , self ._testMethodName + "_after_tf_optimize.pb" )
139
128
with open (model_path , "wb" ) as f :
140
129
f .write (graph_def .SerializeToString ())
@@ -144,45 +133,23 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
144
133
tf .import_graph_def (graph_def , name = '' )
145
134
146
135
with tf .Session () as sess :
147
- g = process_tf_graph (sess .graph , opset = type ( self ). OPSET , output_names = output_names_with_port ,
148
- target = type ( self ). TARGET , ** process_args )
136
+ g = process_tf_graph (sess .graph , opset = self . config . opset , output_names = output_names_with_port ,
137
+ target = self . config . target , ** process_args )
149
138
actual = self ._run_backend (g , output_names_with_port , onnx_feed_dict )
150
139
151
140
for expected_val , actual_val in zip (expected , actual ):
152
141
if check_value :
153
- self .assertAllClose (expected_val , actual_val , rtol = rtol , atol = 0. )
142
+ self .assertAllClose (expected_val , actual_val , rtol = rtol , atol = atol )
154
143
if check_dtype :
155
144
self .assertEqual (expected_val .dtype , actual_val .dtype )
156
145
if check_shape :
157
146
self .assertEqual (expected_val .shape , actual_val .shape )
158
147
159
148
def save_onnx_model (self , model_proto , feed_dict , postfix = "" ):
160
- save_path = os .path .join (type ( self ). TMPPATH , self ._testMethodName )
149
+ save_path = os .path .join (self . config . temp_path , self ._testMethodName )
161
150
target_path = utils .save_onnx_model (save_path , self ._testMethodName + postfix , feed_dict , model_proto ,
162
- include_test_data = self .debug_mode (), as_text = self .debug_mode ())
151
+ include_test_data = self .config .is_debug_mode ,
152
+ as_text = self .config .is_debug_mode )
163
153
164
154
self .log .debug ("create model file: %s" , target_path )
165
155
return target_path
166
-
167
- @staticmethod
168
- def trigger (ut_class ):
169
- parser = argparse .ArgumentParser ()
170
- parser .add_argument ('--backend' , default = Tf2OnnxBackendTestBase .BACKEND ,
171
- choices = ["caffe2" , "onnxmsrtnext" , "onnxruntime" ],
172
- help = "backend to test against" )
173
- parser .add_argument ('--opset' , type = int , default = Tf2OnnxBackendTestBase .OPSET , help = "opset to test against" )
174
- parser .add_argument ("--target" , default = "," .join (DEFAULT_TARGET ), choices = POSSIBLE_TARGETS ,
175
- help = "target platform" )
176
- parser .add_argument ("--debug" , help = "output debugging information" , action = "store_true" )
177
- parser .add_argument ('unittest_args' , nargs = '*' )
178
-
179
- args = parser .parse_args ()
180
- print (args )
181
- ut_class .BACKEND = args .backend
182
- ut_class .OPSET = args .opset
183
- ut_class .DEBUG = args .debug
184
- ut_class .TARGET = args .target
185
-
186
- # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
187
- sys .argv [1 :] = args .unittest_args
188
- unittest .main ()
0 commit comments