Skip to content

Commit d5e56d5

Browse files
authored
Merge pull request #1077 from xadupre/profile
Add a script to profile the conversion of a model
2 parents 41b497e + ee4ab03 commit d5e56d5

File tree

5 files changed

+181
-12
lines changed

5 files changed

+181
-12
lines changed

tests/test_profile.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""Unit Tests for Benchmarks."""
5+
import os
6+
import subprocess
7+
from backend_test_base import Tf2OnnxBackendTestBase
8+
from common import (
9+
check_opset_min_version, check_tf_min_version,
10+
unittest_main, check_onnxruntime_min_version
11+
)
12+
13+
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
14+
# pylint: disable=invalid-name
15+
# pylint: enable=invalid-name
16+
17+
class ProfileTests(Tf2OnnxBackendTestBase):
18+
19+
folder = os.path.join(os.path.dirname(__file__), '..', 'tools')
20+
21+
@check_tf_min_version("2.0")
22+
@check_opset_min_version(12)
23+
@check_onnxruntime_min_version('1.4.0')
24+
def test_profile_conversion_time(self):
25+
filename = os.path.join(ProfileTests.folder, 'profile_conversion_time.py')
26+
proc = subprocess.Popen(
27+
["python", filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
28+
try:
29+
outs = proc.communicate(timeout=15)[0]
30+
except subprocess.TimeoutExpired:
31+
proc.kill()
32+
return
33+
assert b"Profile complete." in outs or outs == b''
34+
35+
36+
if __name__ == '__main__':
37+
unittest_main()

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def get_tensor_value(self, as_list=True):
313313
when as_list=True, return 1, type is <class 'int'>.
314314
"""
315315
if not self.is_const():
316-
raise ValueError("get tensor value: {} must be Const".format(self.name))
316+
raise ValueError("get tensor value: '{}' must be Const".format(self.name))
317317

318318
t = self.get_attr("value")
319319
if t:

tf2onnx/onnx_opset/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def version_1(cls, ctx, node, **kwargs):
3838
# const we make it an attribute.
3939
seed = node.get_attr("seed")
4040
node.set_attr("seed", float(seed.f))
41-
if len(node.input) > 0:
41+
if len(node.input) > 0 and node.inputs[0].is_const():
4242
shape = node.inputs[0].get_tensor_value()
4343
ctx.remove_input(node, node.input[0], 0)
4444
node.set_attr("shape", shape)

tf2onnx/tfonnx.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def tensorflow_onnx_mapping(g, ops_mapping):
286286
func(g, node, **kwargs)
287287
node.skip_conversion = True
288288
except Exception as ex:
289-
logger.error("Failed to convert node %s\n%s", node.name, node.summary, exc_info=1)
289+
logger.error("Failed to convert node %r (fct=%r)\n%r",
290+
node.name, func, node.summary, exc_info=1)
290291
exceptions.append(ex)
291292

292293
return mapped_op, unmapped_op, exceptions
@@ -486,15 +487,30 @@ def compat_handler(ctx, node, **kwargs):
486487

487488
# pre-processing graph rewrites
488489
# bi-directional re-writer should be placed after single directional re-writer
489-
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,
490-
rewrite_random_uniform, rewrite_random_uniform_fold_const,
491-
rewrite_random_normal, rewrite_dropout, rewrite_eye,
492-
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
493-
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
494-
rewrite_single_direction_gru, rewrite_bi_direction_gru,
495-
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
496-
rewrite_biasadd_with_conv2d, rewrite_gemm
497-
]
490+
rewriters = [
491+
# single directional
492+
rewrite_constant_fold,
493+
rewrite_quantize_and_dequantize,
494+
rewrite_transpose,
495+
rewrite_flatten,
496+
rewrite_random_uniform,
497+
rewrite_random_uniform_fold_const,
498+
rewrite_random_normal,
499+
rewrite_dropout,
500+
rewrite_eye,
501+
rewrite_leakyrelu,
502+
rewrite_thresholded_relu,
503+
rewrite_conv2d_with_pad,
504+
rewrite_single_direction_lstm,
505+
# bi-directional
506+
rewrite_bi_direction_lstm,
507+
rewrite_single_direction_gru,
508+
rewrite_bi_direction_gru,
509+
rewrite_custom_rnn_cell,
510+
rewrite_generic_loop, rewrite_cond,
511+
rewrite_biasadd_with_conv2d,
512+
rewrite_gemm,
513+
]
498514

499515
if custom_rewriter is not None:
500516
rewriters.extend(custom_rewriter)

tools/profile_conversion_time.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# coding: utf-8
2+
"""
3+
Profiles the conversion of a Keras model.
4+
"""
5+
import sys
6+
import cProfile
7+
from pstats import SortKey, Stats
8+
import io
9+
import argparse
10+
import tensorflow as tf
11+
from tensorflow.keras.applications import MobileNet, EfficientNetB2
12+
from tf2onnx import tfonnx
13+
try:
14+
from pyinstrument import Profiler
15+
except ImportError:
16+
Profiler = None
17+
18+
19+
def spy_model(name):
20+
"Creates the model."
21+
with tf.compat.v1.Session(graph=tf.Graph()) as session:
22+
if name == "MobileNet":
23+
model = MobileNet()
24+
elif name == "EfficientNetB2":
25+
model = EfficientNetB2()
26+
else:
27+
raise ValueError("Unknown model name %r." % name)
28+
29+
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
30+
sess=session,
31+
input_graph_def=session.graph_def,
32+
output_node_names=[model.output.op.name])
33+
34+
return graph_def, model
35+
36+
37+
def spy_convert(graph_def, model):
38+
"Converts the model."
39+
with tf.Graph().as_default() as graph:
40+
tf.import_graph_def(graph_def=graph_def, name='')
41+
42+
def spy_convert_in():
43+
return tfonnx.process_tf_graph(
44+
tf_graph=graph, input_names=[model.input.name],
45+
output_names=[model.output.name])
46+
47+
spy_convert_in()
48+
49+
50+
def create(name):
51+
"Creates the model."
52+
graph_def, model = spy_model(name)
53+
return graph_def, model
54+
55+
56+
def convert(graph_def, model):
57+
"Converts the model."
58+
spy_convert(graph_def, model)
59+
60+
61+
def profile(profiler="none", name="MobileNet", show_all=False):
62+
"""
63+
Profiles the conversion of a model.
64+
65+
:param profiler: one among none, spy, pyinstrument, cProfile
66+
:param name: model to profile, MobileNet, EfficientNetB2
67+
:param showall: used by pyinstrument to show all functions
68+
"""
69+
print("create(%r, %r)" % (profiler, name))
70+
graph_def, model = create(name)
71+
print("profile(%r, %r)" % (profiler, name))
72+
if profiler == 'none':
73+
convert(graph_def, model)
74+
elif profiler == "spy":
75+
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
76+
convert(graph_def, model)
77+
elif profiler == "pyinstrument":
78+
if Profiler is None:
79+
raise ImportError("pyinstrument is not installed")
80+
profiler = Profiler(interval=0.0001)
81+
profiler.start()
82+
convert(graph_def, model)
83+
profiler.stop()
84+
print(profiler.output_text(unicode=False, color=False, show_all=show_all))
85+
elif profiler == "cProfile":
86+
pr = cProfile.Profile()
87+
pr.enable()
88+
convert(graph_def, model)
89+
pr.disable()
90+
s = io.StringIO()
91+
sortby = SortKey.CUMULATIVE
92+
ps = Stats(pr, stream=s).sort_stats(sortby)
93+
ps.print_stats()
94+
print(s.getvalue())
95+
else:
96+
raise ValueError("Unknown profiler %r." % profiler)
97+
98+
99+
def main(args):
100+
parser = argparse.ArgumentParser(description='Process some integers.')
101+
parser.add_argument('--profiler', default='none',
102+
choices=['none', 'spy', 'pyinstrument', 'cProfile'],
103+
help='a profiler')
104+
parser.add_argument('--name', default="MobileNet",
105+
choices=['MobileNet', 'EfficientNetB2'],
106+
help="a model")
107+
parser.add_argument('--showall', type=bool, default=False,
108+
help="used by pyinstrument to show all functions")
109+
res = parser.parse_args(args)
110+
profile(res.profiler, res.name, res.showall)
111+
112+
113+
if __name__ == '__main__':
114+
print('Begin profiling with', sys.argv[1:])
115+
main(sys.argv[1:])
116+
print('Profile complete.')

0 commit comments

Comments
 (0)