41
41
from tf2onnx import tf_loader , logging , optimizer , utils , tf_utils
42
42
from tf2onnx .tfonnx import process_tf_graph
43
43
from tf2onnx .tf_loader import tf_session , tf_reset_default_graph
44
+ from tf2onnx .graph import ExternalTensorStorage
44
45
45
46
logger = logging .getLogger ("run_pretrained" )
46
47
@@ -102,16 +103,20 @@ class Test(object):
102
103
cache_dir = None
103
104
target = []
104
105
105
- def __init__ (self , url , local , make_input , input_names , output_names ,
106
+ def __init__ (self , url , local , input_func , input_names , output_names ,
106
107
disabled = False , rtol = 0.01 , atol = 1e-6 ,
107
108
check_only_shape = False , model_type = "frozen" , force_input_shape = False ,
108
- skip_tensorflow = False , opset_constraints = None , tf_min_version = None , tag = None ):
109
+ skip_tensorflow = False , opset_constraints = None , tf_min_version = None , tag = None ,
110
+ skip_conversion = False , converted_model = None , signature_def = None , concrete_function = None ,
111
+ large_model = False , structured_outputs = None ):
109
112
self .url = url
110
- self .make_input = make_input
113
+ self .input_func = input_func
111
114
self .local = local
112
115
self .input_names = input_names
113
116
self .output_names = output_names
114
117
self .disabled = disabled
118
+ self .large_model = large_model
119
+ self .structured_outputs = structured_outputs # Needed to determine output order for tf_function
115
120
self .rtol = rtol
116
121
self .atol = atol
117
122
self .check_only_shape = check_only_shape
@@ -122,8 +127,18 @@ def __init__(self, url, local, make_input, input_names, output_names,
122
127
self .tag = tag
123
128
self .force_input_shape = force_input_shape
124
129
self .skip_tensorflow = skip_tensorflow
130
+ self .skip_conversion = skip_conversion
131
+ self .converted_model = converted_model
125
132
self .opset_constraints = opset_constraints
126
133
self .tf_min_version = tf_min_version
134
+ self .signatures = [signature_def ] if signature_def else None
135
+ self .concrete_function = concrete_function
136
+
137
+ def make_input (self , v ):
138
+ """Allows each input to specify its own function while defaulting to the input_get function"""
139
+ if isinstance (v , dict ):
140
+ return _INPUT_FUNC_MAPPING [v ["input_get" ]](v ["shape" ])
141
+ return self .input_func (v )
127
142
128
143
def download_model (self ):
129
144
"""Download model from url."""
@@ -149,7 +164,7 @@ def download_model(self):
149
164
if not os .path .exists (fpath ):
150
165
utils .get_url (url , fpath )
151
166
model_path = os .path .join (dir_name , self .local )
152
- if not os .path .exists (model_path ):
167
+ if not os .path .exists (model_path ) or self . local == "." :
153
168
if ftype == 'tgz' :
154
169
tar = tarfile .open (fpath )
155
170
tar .extractall (dir_name )
@@ -179,19 +194,23 @@ def run_tensorflow(self, sess, inputs):
179
194
for k , v in inputs .items ():
180
195
k = sess .graph .get_tensor_by_name (k )
181
196
feed_dict [k ] = v
197
+ logger .info ("Running TF" )
182
198
result = sess .run (self .output_names , feed_dict = feed_dict )
183
199
if self .perf :
200
+ logger .info ("Running TF perf" )
184
201
start = time .time ()
185
202
for _ in range (PERFITER ):
186
203
_ = sess .run (self .output_names , feed_dict = feed_dict )
187
204
self .tf_runtime = time .time () - start
188
205
return result
189
206
190
- def to_onnx (self , tf_graph , opset = None , extra_opset = None , shape_override = None , input_names = None ):
207
+ def to_onnx (self , tf_graph , opset = None , extra_opset = None , shape_override = None , input_names = None ,
208
+ const_node_values = None ):
191
209
"""Convert graph to tensorflow."""
192
210
return process_tf_graph (tf_graph , continue_on_error = False , opset = opset ,
193
211
extra_opset = extra_opset , target = Test .target , shape_override = shape_override ,
194
- input_names = input_names , output_names = self .output_names )
212
+ input_names = input_names , output_names = self .output_names ,
213
+ const_node_values = const_node_values )
195
214
196
215
def run_caffe2 (self , name , model_proto , inputs ):
197
216
"""Run test again caffe2 backend."""
@@ -205,11 +224,12 @@ def run_caffe2(self, name, model_proto, inputs):
205
224
self .onnx_runtime = time .time () - start
206
225
return results
207
226
208
- def run_onnxruntime (self , name , model_proto , inputs ):
227
+ def run_onnxruntime (self , name , model_proto , inputs , external_tensor_storage = None ):
209
228
"""Run test against onnxruntime backend."""
210
229
import onnxruntime as rt
211
230
model_path = utils .save_onnx_model (TEMP_DIR , name , inputs , model_proto , include_test_data = True ,
212
- as_text = utils .is_debug_mode ())
231
+ as_text = utils .is_debug_mode (),
232
+ external_tensor_storage = external_tensor_storage )
213
233
logger .info ("Model saved to %s" , model_path )
214
234
m = rt .InferenceSession (model_path )
215
235
results = m .run (self .output_names , inputs )
@@ -221,10 +241,14 @@ def run_onnxruntime(self, name, model_proto, inputs):
221
241
return results
222
242
223
243
@staticmethod
224
- def create_onnx_file (name , model_proto , inputs , outdir ):
244
+ def create_onnx_file (name , model_proto , inputs , outdir , external_tensor_storage = None ):
225
245
os .makedirs (outdir , exist_ok = True )
226
- model_path = os .path .join (outdir , name + ".onnx" )
227
- utils .save_protobuf (model_path , model_proto )
246
+ if external_tensor_storage is None :
247
+ model_path = os .path .join (outdir , name + ".onnx" )
248
+ utils .save_protobuf (model_path , model_proto )
249
+ else :
250
+ model_path = os .path .join (outdir , name + ".zip" )
251
+ utils .save_onnx_zip (model_path , model_proto , external_tensor_storage )
228
252
logger .info ("Created %s" , model_path )
229
253
230
254
def run_test (self , name , backend = "caffe2" , onnx_file = None , opset = None , extra_opset = None ,
@@ -236,7 +260,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
236
260
if self .url :
237
261
_ , dir_name = self .download_model ()
238
262
logger .info ("Downloaded to %s" , dir_name )
239
- model_path = os .path .join (dir_name , self .local )
263
+ model_path = os .path .join (dir_name , self .local ) if self . local != "." else dir_name
240
264
else :
241
265
model_path = self .local
242
266
@@ -246,13 +270,15 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
246
270
if self .model_type in ["checkpoint" ]:
247
271
graph_def , input_names , outputs = tf_loader .from_checkpoint (model_path , input_names , outputs )
248
272
elif self .model_type in ["saved_model" ]:
249
- try :
250
- res = tf_loader .from_saved_model (model_path , input_names , outputs , self .tag )
251
- except OSError :
252
- model_path = dir_name
253
- logger .info ("Load model(2) from %r" , model_path )
254
- res = tf_loader .from_saved_model (model_path , input_names , outputs , self .tag )
255
- graph_def , input_names , outputs = res [:3 ]
273
+ loaded = tf_loader .from_saved_model (model_path , input_names , outputs , self .tag , self .signatures ,
274
+ self .concrete_function , self .large_model ,
275
+ return_concrete_func = self .large_model )
276
+ if self .large_model :
277
+ # Must maintain ref to imported since concrete_func uses weak refs
278
+ # pylint: disable=unused-variable
279
+ graph_def , input_names , outputs , concrete_func , imported = loaded
280
+ else :
281
+ graph_def , input_names , outputs = loaded
256
282
elif self .model_type in ["keras" ]:
257
283
graph_def , input_names , outputs = tf_loader .from_keras (model_path , input_names , outputs )
258
284
else :
@@ -261,9 +287,34 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
261
287
if utils .is_debug_mode ():
262
288
utils .save_protobuf (os .path .join (TEMP_DIR , name + "_after_tf_optimize.pb" ), graph_def )
263
289
290
+ if self .large_model :
291
+ inputs = {}
292
+ for k in input_names :
293
+ v = self .input_names [k ]
294
+ inputs [k .split (":" )[0 ]] = tf .constant (self .make_input (v ))
295
+ tf_func = tf .function (concrete_func )
296
+ logger .info ("Running TF" )
297
+ tf_results_d = tf_func (** inputs )
298
+ if self .structured_outputs is None :
299
+ tf_results = list (tf_results_d .values ())
300
+ else :
301
+ tf_results = [tf_results_d [output ] for output in self .structured_outputs ]
302
+ if self .perf :
303
+ logger .info ("Running TF perf" )
304
+ start = time .time ()
305
+ for _ in range (PERFITER ):
306
+ _ = concrete_func (** inputs )
307
+ self .tf_runtime = time .time () - start
308
+ logger .info ("TensorFlow OK" )
309
+
264
310
inputs = {}
265
311
shape_override = {}
266
312
tf_reset_default_graph ()
313
+
314
+ from tf2onnx .tf_utils import compress_graph_def
315
+ const_node_values = None
316
+ if self .large_model :
317
+ const_node_values = compress_graph_def (graph_def )
267
318
g = tf .import_graph_def (graph_def , name = '' )
268
319
# with tf_session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
269
320
with tf_session (graph = g ) as sess :
@@ -288,30 +339,50 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
288
339
# run the model with tensorflow
289
340
if self .skip_tensorflow :
290
341
logger .info ("TensorFlow SKIPPED" )
291
- else :
342
+ elif not self . large_model :
292
343
tf_results = self .run_tensorflow (sess , inputs )
293
344
logger .info ("TensorFlow OK" )
294
345
295
346
model_proto = None
296
- try :
297
- # convert model to onnx
298
- onnx_graph = self .to_onnx (sess .graph , opset = opset , extra_opset = extra_opset ,
299
- shape_override = shape_override , input_names = inputs .keys ())
300
- onnx_graph = optimizer .optimize_graph (onnx_graph )
301
- model_proto = onnx_graph .make_model ("converted from tf2onnx" )
302
- logger .info ("To_ONNX, OK" )
303
- if onnx_file :
304
- self .create_onnx_file (name , model_proto , inputs , onnx_file )
305
- except Exception :
306
- logger .error ("To_ONNX FAIL" , exc_info = 1 )
307
- return False
347
+ if self .skip_conversion :
348
+ if self .large_model :
349
+ external_tensor_storage = ExternalTensorStorage ()
350
+ model_proto = utils .model_proto_from_zip (self .converted_model , external_tensor_storage )
351
+ else :
352
+ external_tensor_storage = None
353
+ model_proto = utils .model_proto_from_file (self .converted_model )
354
+ logger .info ("ONNX loaded from file" )
355
+ else :
356
+ try :
357
+ # convert model to onnx
358
+ onnx_graph = self .to_onnx (sess .graph , opset = opset , extra_opset = extra_opset ,
359
+ shape_override = shape_override , input_names = inputs .keys (),
360
+ const_node_values = const_node_values )
361
+ onnx_graph = optimizer .optimize_graph (onnx_graph )
362
+ print ("ONNX" , onnx_graph .dump_node_statistics ())
363
+ external_tensor_storage = ExternalTensorStorage () if self .large_model else None
364
+ model_proto = onnx_graph .make_model ("converted from tf2onnx" ,
365
+ external_tensor_storage = external_tensor_storage )
366
+ logger .info ("To_ONNX, OK" )
367
+ if onnx_file :
368
+ self .create_onnx_file (name , model_proto , inputs , onnx_file , external_tensor_storage )
369
+ if self .converted_model :
370
+ if self .large_model :
371
+ utils .save_onnx_zip (self .converted_model , model_proto , external_tensor_storage )
372
+ else :
373
+ utils .save_protobuf (self .converted_model , model_proto )
374
+ logger .info ("Created %s" , self .converted_model )
375
+
376
+ except Exception :
377
+ logger .error ("To_ONNX FAIL" , exc_info = 1 )
378
+ return False
308
379
309
380
try :
310
381
onnx_results = None
311
382
if backend == "caffe2" :
312
383
onnx_results = self .run_caffe2 (name , model_proto , inputs )
313
384
elif backend == "onnxruntime" :
314
- onnx_results = self .run_onnxruntime (name , model_proto , inputs )
385
+ onnx_results = self .run_onnxruntime (name , model_proto , inputs , external_tensor_storage )
315
386
else :
316
387
raise ValueError ("unknown backend" )
317
388
logger .info ("Run_ONNX OK" )
@@ -390,6 +461,7 @@ def get_args():
390
461
parser .add_argument ("--list" , help = "list tests" , action = "store_true" )
391
462
parser .add_argument ("--onnx-file" , help = "create onnx file in directory" )
392
463
parser .add_argument ("--perf" , help = "capture performance numbers" )
464
+ parser .add_argument ("--perfiter" , type = int , default = PERFITER , help = "number of inferences for perf testing" )
393
465
parser .add_argument ("--fold_const" , help = "enable tf constant_folding transformation before conversion" ,
394
466
action = "store_true" )
395
467
parser .add_argument ("--include-disabled" , help = "include disabled tests" , action = "store_true" )
@@ -447,8 +519,9 @@ def load_tests_from_yaml(path):
447
519
opset_constraints .append (c )
448
520
449
521
kwargs = {}
450
- for kw in ["rtol" , "atol" , "disabled" , "check_only_shape" , "model_type" ,
451
- "skip_tensorflow" , "force_input_shape" , "tf_min_version" , "tag" ]:
522
+ for kw in ["rtol" , "atol" , "disabled" , "check_only_shape" , "model_type" , "concrete_function" ,
523
+ "skip_tensorflow" , "force_input_shape" , "tf_min_version" , "tag" , "skip_conversion" ,
524
+ "converted_model" , "signature_def" , "large_model" , "structured_outputs" ]:
452
525
if settings .get (kw ) is not None :
453
526
kwargs [kw ] = settings [kw ]
454
527
@@ -459,6 +532,7 @@ def load_tests_from_yaml(path):
459
532
460
533
461
534
def main ():
535
+ global PERFITER
462
536
args = get_args ()
463
537
logging .basicConfig (level = logging .get_verbosity_level (args .verbose ))
464
538
if args .debug :
@@ -477,6 +551,7 @@ def main():
477
551
478
552
failed = 0
479
553
count = 0
554
+ PERFITER = args .perfiter
480
555
for test in test_keys :
481
556
logger .info ("===================================" )
482
557
@@ -520,7 +595,8 @@ def main():
520
595
for test in test_keys :
521
596
t = tests [test ]
522
597
if t .perf :
523
- f .write ("{},{},{}\n " .format (test , t .tf_runtime , t .onnx_runtime ))
598
+ # Report perf in ms per inference
599
+ f .write ("{},{},{}\n " .format (test , t .tf_runtime * 1000 / PERFITER , t .onnx_runtime * 1000 / PERFITER ))
524
600
return failed
525
601
526
602
0 commit comments