Skip to content

Commit feab45c

Browse files
committed
lint
1 parent febb438 commit feab45c

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

tools/profile_conversion_time.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22
"""
33
Profiles the conversion of a Keras model.
44
"""
5+
import cProfile, pstats, io
6+
from pstats import SortKey
57
import fire
68
import tensorflow as tf
79
from tf2onnx import tfonnx
8-
from tensorflow.keras.applications import MobileNet
9-
from tensorflow.keras.applications import EfficientNetB2
10+
from tensorflow.keras.applications import MobileNet, EfficientNetB2
11+
try:
12+
from pyinstrument import Profiler
13+
except ImportError:
14+
Profiler = None
1015

1116

12-
def spy_model(k, name):
17+
def spy_model(name):
1318
"Creates the model."
1419
with tf.compat.v1.Session(graph=tf.Graph()) as session:
1520
if name == "MobileNet":
@@ -40,13 +45,9 @@ def spy_convert_in():
4045
spy_convert_in()
4146

4247

43-
def create(name, module):
48+
def create(name):
4449
"Creates the model."
45-
if module == 'tf.keras':
46-
mod = tf.keras
47-
else:
48-
raise ValueError("Unknown module '{}'.".format(module))
49-
graph_def, model = spy_model(mod, name)
50+
graph_def, model = spy_model(name)
5051
return graph_def, model
5152

5253

@@ -55,37 +56,31 @@ def convert(graph_def, model):
5556
spy_convert(graph_def, model)
5657

5758

58-
def profile(profiler="none", name="MobileNet", show_all=False,
59-
module='tf.keras'):
59+
def profile(profiler="none", name="MobileNet", show_all=False):
6060
"""
6161
Profiles the conversion of a model.
6262
6363
:param profiler: one among none, spy, pyinstrument, cProfile
6464
:param name: model to profile, MobileNet, EfficientNetB2
6565
:param show_all: use by pyinstrument to show all functions
6666
"""
67-
print("create(%r, %r, %r)" % (profiler, name, module))
68-
graph_def, model = create(name, module)
69-
print("profile(%r, %r, %r)" % (profiler, name, module))
67+
print("create(%r, %r)" % (profiler, name))
68+
graph_def, model = create(name)
69+
print("profile(%r, %r)" % (profiler, name))
7070
if profiler == 'none':
7171
convert(graph_def, model)
7272
elif profiler == "spy":
7373
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
7474
convert(graph_def, model)
7575
elif profiler == "pyinstrument":
76-
from pyinstrument import Profiler
77-
76+
if Profiler is None:
77+
raise ImportError("pyinstrument is not installed")
7878
profiler = Profiler(interval=0.0001)
7979
profiler.start()
80-
8180
convert(graph_def, model)
82-
8381
profiler.stop()
8482
print(profiler.output_text(unicode=False, color=False, show_all=show_all))
8583
elif profiler == "cProfile":
86-
import cProfile, pstats, io
87-
from pstats import SortKey
88-
8984
pr = cProfile.Profile()
9085
pr.enable()
9186
convert(graph_def, model)

0 commit comments

Comments
 (0)