14
14
import argparse
15
15
import os
16
16
import re
17
+ import shutil
17
18
import sys
18
19
import tarfile
19
20
import tempfile
@@ -187,14 +188,17 @@ def __init__(self, url, local, input_func, input_names, output_names,
187
188
check_only_shape = False , model_type = "frozen" , force_input_shape = False ,
188
189
skip_tensorflow = False , opset_constraints = None , tf_min_version = None , tag = None ,
189
190
skip_conversion = False , converted_model = None , signature_def = None , concrete_function = None ,
190
- large_model = False , structured_outputs = None , run_tf_frozen = None , use_custom_ops = False ):
191
+ large_model = False , structured_outputs = None , run_tf_frozen = None , use_custom_ops = False ,
192
+ ort_profile = None , tf_profile = None ):
191
193
self .url = url
192
194
self .input_func = input_func
193
195
self .local = local
194
196
self .input_names = input_names
195
197
self .output_names = output_names
196
198
self .disabled = disabled
197
199
self .large_model = large_model
200
+ self .ort_profile = ort_profile
201
+ self .tf_profile = tf_profile
198
202
self .use_custom_ops = use_custom_ops
199
203
if run_tf_frozen is None :
200
204
run_tf_frozen = not self .large_model
@@ -324,13 +328,14 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
324
328
as_text = utils .is_debug_mode (),
325
329
external_tensor_storage = external_tensor_storage )
326
330
logger .info ("Model saved to %s" , model_path )
331
+ opt = rt .SessionOptions ()
327
332
if self .use_custom_ops :
328
333
from ortcustomops import get_library_path
329
- opt = rt .SessionOptions ()
330
334
opt .register_custom_ops_library (get_library_path ())
331
335
m = rt .InferenceSession (model_path , opt )
332
- else :
333
- m = rt .InferenceSession (model_path )
336
+ if self .ort_profile is not None :
337
+ opt .enable_profiling = True
338
+ m = rt .InferenceSession (model_path , opt )
334
339
results = m .run (outputs , inputs )
335
340
if self .perf :
336
341
n = 0
@@ -342,6 +347,9 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
342
347
n += PERF_STEP
343
348
self .onnx_runtime = 1000 * (time .time () - start ) / n
344
349
logger .info ("ORT perf {:.2f}ms/inference, n={}" .format (self .onnx_runtime , n ))
350
+ if self .ort_profile is not None :
351
+ tmp_path = m .end_profiling ()
352
+ shutil .move (tmp_path , self .ort_profile )
345
353
return results
346
354
347
355
@staticmethod
@@ -449,10 +457,14 @@ def run_tflite():
449
457
n = 0
450
458
start = time .time ()
451
459
stop = start + PERF_TIME
460
+ if self .tf_profile is not None :
461
+ tf .profiler .experimental .start (self .tf_profile )
452
462
while time .time () < stop :
453
463
for _ in range (PERF_STEP ):
454
464
_ = concrete_func (** inputs )
455
465
n += PERF_STEP
466
+ if self .tf_profile is not None :
467
+ tf .profiler .experimental .stop ()
456
468
self .tf_runtime = 1000 * (time .time () - start ) / n
457
469
logger .info ("TF perf {:.2f}ms/inference, n={}" .format (self .tf_runtime , n ))
458
470
logger .info ("TensorFlow OK" )
@@ -497,7 +509,11 @@ def run_tflite():
497
509
if self .skip_tensorflow :
498
510
logger .info ("TensorFlow SKIPPED" )
499
511
elif self .run_tf_frozen :
512
+ if self .tf_profile is not None :
513
+ tf .profiler .experimental .start (self .tf_profile )
500
514
tf_results = self .run_tensorflow (sess , inputs )
515
+ if self .tf_profile is not None :
516
+ tf .profiler .experimental .stop ()
501
517
logger .info ("TensorFlow OK" )
502
518
tf_graph = sess .graph
503
519
@@ -690,7 +706,7 @@ def load_tests_from_yaml(path):
690
706
for kw in ["rtol" , "atol" , "ptol" , "disabled" , "check_only_shape" , "model_type" , "concrete_function" ,
691
707
"skip_tensorflow" , "force_input_shape" , "tf_min_version" , "tag" , "skip_conversion" ,
692
708
"converted_model" , "signature_def" , "large_model" , "structured_outputs" , "run_tf_frozen" ,
693
- "use_custom_ops" , "dequantize" ]:
709
+ "use_custom_ops" , "dequantize" , "ort_profile" , "tf_profile" ]:
694
710
if settings .get (kw ) is not None :
695
711
kwargs [kw ] = settings [kw ]
696
712
0 commit comments