Skip to content

Commit 5af5c52

Browse files
committed
add script to profile
1 parent 080c7a4 commit 5af5c52

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

benchmarks/profile_conversion_time.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# coding: utf-8
2+
"""
3+
Profiles the conversion of a Keras model.
4+
"""
5+
import fire
6+
import tensorflow as tf
7+
from tf2onnx import tfonnx
8+
from tensorflow.keras.applications import MobileNet
9+
from tensorflow.keras.applications import EfficientNetB2
10+
11+
12+
def spy_model(k, name):
13+
"Creates the model."
14+
with tf.compat.v1.Session(graph=tf.Graph()) as session:
15+
if name == "MobileNet":
16+
model = MobileNet()
17+
elif name == "EfficientNetB2":
18+
model = EfficientNetB2()
19+
else:
20+
raise ValueError("Unknown model name %r." % name)
21+
22+
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
23+
sess=session,
24+
input_graph_def=session.graph_def,
25+
output_node_names=[model.output.op.name])
26+
27+
return graph_def, model
28+
29+
30+
def spy_convert(graph_def, model):
31+
"Converts the model."
32+
with tf.Graph().as_default() as graph:
33+
tf.import_graph_def(graph_def=graph_def, name='')
34+
35+
def spy_convert_in():
36+
return tfonnx.process_tf_graph(tf_graph=graph,
37+
input_names=[model.input.name],
38+
output_names=[model.output.name])
39+
40+
spy_convert_in()
41+
42+
43+
def create(name, module):
44+
"Creates the model."
45+
if module == 'tf.keras':
46+
mod = tf.keras
47+
else:
48+
raise ValueError("Unknown module '{}'.".format(module))
49+
graph_def, model = spy_model(mod, name)
50+
return graph_def, model
51+
52+
53+
def convert(graph_def, model):
54+
"Converts the model."
55+
spy_convert(graph_def, model)
56+
57+
58+
def profile(profiler="pyinstrument", name="MobileNet", show_all=False,
59+
module='tf.keras'):
60+
"""
61+
Profiles the conversion of a model.
62+
63+
:param profiler: one among spy, pyinstrument, cProfile
64+
:param name: model to profile, MobileNet, EfficientNetB2
65+
:param show_all: use by pyinstrument to show all functions
66+
"""
67+
print("create(%r, %r, %r)" % (profiler, name, module))
68+
graph_def, model = create(name, module)
69+
print("profile(%r, %r, %r)" % (profiler, name, module))
70+
if profiler == "spy":
71+
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
72+
convert(graph_def, model)
73+
elif profiler == "pyinstrument":
74+
from pyinstrument import Profiler
75+
76+
profiler = Profiler(interval=0.0001)
77+
profiler.start()
78+
79+
convert(graph_def, model)
80+
81+
profiler.stop()
82+
print(profiler.output_text(unicode=False, color=False, show_all=show_all))
83+
elif profiler == "cProfile":
84+
import cProfile, pstats, io
85+
from pstats import SortKey
86+
87+
pr = cProfile.Profile()
88+
pr.enable()
89+
convert(graph_def, model)
90+
pr.disable()
91+
s = io.StringIO()
92+
sortby = SortKey.CUMULATIVE
93+
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
94+
ps.print_stats()
95+
print(s.getvalue())
96+
else:
97+
raise ValueError("Unknown profiler %r." % profiler)
98+
99+
100+
if __name__ == '__main__':
101+
fire.Fire(profile)

ci_build/azure_pipelines/templates/unit_test.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,13 @@ steps:
1414
condition: succeededOrFailed()
1515
env:
1616
CI_ONNX_OPSET: '${{ onnx_opset }}'
17+
18+
- bash: |
19+
export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
20+
export TF2ONNX_TEST_OPSET=$CI_ONNX_OPSET
21+
python benchmarks/profile_conversion_time.py
22+
timeoutInMinutes: 15
23+
displayName: ${{ format('Run profile_conversion_time.py - Opset{0}', onnx_opset) }}
24+
condition: succeededOrFailed()
25+
env:
26+
CI_ONNX_OPSET: '${{ onnx_opset }}'

0 commit comments

Comments
 (0)