2
2
"""
3
3
Profiles the conversion of a Keras model.
4
4
"""
5
+ import cProfile , pstats , io
6
+ from pstats import SortKey
5
7
import fire
6
8
import tensorflow as tf
7
9
from tf2onnx import tfonnx
8
- from tensorflow .keras .applications import MobileNet
9
- from tensorflow .keras .applications import EfficientNetB2
10
+ from tensorflow .keras .applications import MobileNet , EfficientNetB2
11
+ try :
12
+ from pyinstrument import Profiler
13
+ except ImportError :
14
+ Profiler = None
10
15
11
16
12
- def spy_model (k , name ):
17
+ def spy_model (name ):
13
18
"Creates the model."
14
19
with tf .compat .v1 .Session (graph = tf .Graph ()) as session :
15
20
if name == "MobileNet" :
@@ -40,13 +45,9 @@ def spy_convert_in():
40
45
spy_convert_in ()
41
46
42
47
43
- def create (name , module ):
48
+ def create (name ):
44
49
"Creates the model."
45
- if module == 'tf.keras' :
46
- mod = tf .keras
47
- else :
48
- raise ValueError ("Unknown module '{}'." .format (module ))
49
- graph_def , model = spy_model (mod , name )
50
+ graph_def , model = spy_model (name )
50
51
return graph_def , model
51
52
52
53
@@ -55,37 +56,31 @@ def convert(graph_def, model):
55
56
spy_convert (graph_def , model )
56
57
57
58
58
- def profile (profiler = "none" , name = "MobileNet" , show_all = False ,
59
- module = 'tf.keras' ):
59
+ def profile (profiler = "none" , name = "MobileNet" , show_all = False ):
60
60
"""
61
61
Profiles the conversion of a model.
62
62
63
63
:param profiler: one among none, spy, pyinstrument, cProfile
64
64
:param name: model to profile, MobileNet, EfficientNetB2
65
65
:param show_all: use by pyinstrument to show all functions
66
66
"""
67
- print ("create(%r, %r, %r )" % (profiler , name , module ))
68
- graph_def , model = create (name , module )
69
- print ("profile(%r, %r, %r )" % (profiler , name , module ))
67
+ print ("create(%r, %r)" % (profiler , name ))
68
+ graph_def , model = create (name )
69
+ print ("profile(%r, %r)" % (profiler , name ))
70
70
if profiler == 'none' :
71
71
convert (graph_def , model )
72
72
elif profiler == "spy" :
73
73
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
74
74
convert (graph_def , model )
75
75
elif profiler == "pyinstrument" :
76
- from pyinstrument import Profiler
77
-
76
+ if Profiler is None :
77
+ raise ImportError ( "pyinstrument is not installed" )
78
78
profiler = Profiler (interval = 0.0001 )
79
79
profiler .start ()
80
-
81
80
convert (graph_def , model )
82
-
83
81
profiler .stop ()
84
82
print (profiler .output_text (unicode = False , color = False , show_all = show_all ))
85
83
elif profiler == "cProfile" :
86
- import cProfile , pstats , io
87
- from pstats import SortKey
88
-
89
84
pr = cProfile .Profile ()
90
85
pr .enable ()
91
86
convert (graph_def , model )
0 commit comments