Skip to content

Commit 1813ca2

Browse files
Add option to run_pretrained_models to save ort/tf profile (#1509)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8281254 commit 1813ca2

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

tests/run_pretrained_models.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import argparse
1515
import os
1616
import re
17+
import shutil
1718
import sys
1819
import tarfile
1920
import tempfile
@@ -187,14 +188,17 @@ def __init__(self, url, local, input_func, input_names, output_names,
187188
check_only_shape=False, model_type="frozen", force_input_shape=False,
188189
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
189190
skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
190-
large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False):
191+
large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False,
192+
ort_profile=None, tf_profile=None):
191193
self.url = url
192194
self.input_func = input_func
193195
self.local = local
194196
self.input_names = input_names
195197
self.output_names = output_names
196198
self.disabled = disabled
197199
self.large_model = large_model
200+
self.ort_profile = ort_profile
201+
self.tf_profile = tf_profile
198202
self.use_custom_ops = use_custom_ops
199203
if run_tf_frozen is None:
200204
run_tf_frozen = not self.large_model
@@ -324,13 +328,14 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
324328
as_text=utils.is_debug_mode(),
325329
external_tensor_storage=external_tensor_storage)
326330
logger.info("Model saved to %s", model_path)
331+
opt = rt.SessionOptions()
327332
if self.use_custom_ops:
328333
from ortcustomops import get_library_path
329-
opt = rt.SessionOptions()
330334
opt.register_custom_ops_library(get_library_path())
331335
m = rt.InferenceSession(model_path, opt)
332-
else:
333-
m = rt.InferenceSession(model_path)
336+
if self.ort_profile is not None:
337+
opt.enable_profiling = True
338+
m = rt.InferenceSession(model_path, opt)
334339
results = m.run(outputs, inputs)
335340
if self.perf:
336341
n = 0
@@ -342,6 +347,9 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
342347
n += PERF_STEP
343348
self.onnx_runtime = 1000 * (time.time() - start) / n
344349
logger.info("ORT perf {:.2f}ms/inference, n={}".format(self.onnx_runtime, n))
350+
if self.ort_profile is not None:
351+
tmp_path = m.end_profiling()
352+
shutil.move(tmp_path, self.ort_profile)
345353
return results
346354

347355
@staticmethod
@@ -449,10 +457,14 @@ def run_tflite():
449457
n = 0
450458
start = time.time()
451459
stop = start + PERF_TIME
460+
if self.tf_profile is not None:
461+
tf.profiler.experimental.start(self.tf_profile)
452462
while time.time() < stop:
453463
for _ in range(PERF_STEP):
454464
_ = concrete_func(**inputs)
455465
n += PERF_STEP
466+
if self.tf_profile is not None:
467+
tf.profiler.experimental.stop()
456468
self.tf_runtime = 1000 * (time.time() - start) / n
457469
logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
458470
logger.info("TensorFlow OK")
@@ -497,7 +509,11 @@ def run_tflite():
497509
if self.skip_tensorflow:
498510
logger.info("TensorFlow SKIPPED")
499511
elif self.run_tf_frozen:
512+
if self.tf_profile is not None:
513+
tf.profiler.experimental.start(self.tf_profile)
500514
tf_results = self.run_tensorflow(sess, inputs)
515+
if self.tf_profile is not None:
516+
tf.profiler.experimental.stop()
501517
logger.info("TensorFlow OK")
502518
tf_graph = sess.graph
503519

@@ -690,7 +706,7 @@ def load_tests_from_yaml(path):
690706
for kw in ["rtol", "atol", "ptol", "disabled", "check_only_shape", "model_type", "concrete_function",
691707
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
692708
"converted_model", "signature_def", "large_model", "structured_outputs", "run_tf_frozen",
693-
"use_custom_ops", "dequantize"]:
709+
"use_custom_ops", "dequantize", "ort_profile", "tf_profile"]:
694710
if settings.get(kw) is not None:
695711
kwargs[kw] = settings[kw]
696712

0 commit comments

Comments
 (0)