17
17
import yaml
18
18
from tensorflow .core .framework import graph_pb2
19
19
from tf2onnx .tfonnx import process_tf_graph
20
+ from tensorflow .python .framework .graph_util import convert_variables_to_constants
20
21
21
22
TMPPATH = tempfile .mkdtemp ()
22
23
PERFITER = 1000
@@ -77,11 +78,29 @@ def node_name(name):
77
78
return name
78
79
79
80
81
+ def freeze_session (sess , keep_var_names = None , output_names = None , clear_devices = True ):
82
+ """Freezes the state of a session into a pruned computation graph."""
83
+ output_names = [i .replace (":0" , "" ) for i in output_names ]
84
+ graph = sess .graph
85
+ with graph .as_default ():
86
+ freeze_var_names = list (set (v .op .name for v in tf .global_variables ()).difference (keep_var_names or []))
87
+ output_names = output_names or []
88
+ output_names += [v .op .name for v in tf .global_variables ()]
89
+ input_graph_def = graph .as_graph_def ()
90
+ if clear_devices :
91
+ for node in input_graph_def .node :
92
+ node .device = ""
93
+ frozen_graph = convert_variables_to_constants (sess , input_graph_def ,
94
+ output_names , freeze_var_names )
95
+ return frozen_graph
96
+
97
+
80
98
class Test (object ):
81
99
cache_dir = None
82
100
83
101
def __init__ (self , url , local , make_input , input_names , output_names ,
84
- disabled = False , more_inputs = None , rtol = 0.01 , atol = 0. , check_only_shape = False ):
102
+ disabled = False , more_inputs = None , rtol = 0.01 , atol = 0. ,
103
+ check_only_shape = False , model_type = "frozen" ):
85
104
self .url = url
86
105
self .make_input = make_input
87
106
self .local = local
@@ -95,6 +114,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
95
114
self .perf = None
96
115
self .tf_runtime = 0
97
116
self .onnx_runtime = 0
117
+ self .model_type = model_type
98
118
99
119
def download_file (self ):
100
120
"""Download file from url."""
@@ -231,8 +251,18 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
231
251
model_path = os .path .join (dir_name , self .local )
232
252
else :
233
253
model_path = self .local
254
+ dir_name = os .path .dirname (self .local )
234
255
print ("\t downloaded" , model_path )
235
256
257
+ # if the input model is a checkpoint, convert it to a frozen model
258
+ if self .model_type in ["checkpoint" ]:
259
+ saver = tf .train .import_meta_graph (model_path )
260
+ with tf .Session () as sess :
261
+ saver .restore (sess , model_path [:- 5 ])
262
+ frozen_graph = freeze_session (sess , output_names = self .output_names )
263
+ tf .train .write_graph (frozen_graph , dir_name , "frozen.pb" , as_text = False )
264
+ model_path = os .path .join (dir_name , "frozen.pb" )
265
+
236
266
inputs = self .make_input (self .input_names )
237
267
if self .more_inputs :
238
268
for k , v in self .more_inputs .items ():
@@ -314,7 +344,7 @@ def tests_from_yaml(fname):
314
344
input_func = v .get ("input_get" )
315
345
input_func = _INPUT_FUNC_MAPPING [input_func ]
316
346
kwargs = {}
317
- for kw in ["rtol" , "atol" , "disabled" , "more_inputs" , "check_only_shape" ]:
347
+ for kw in ["rtol" , "atol" , "disabled" , "more_inputs" , "check_only_shape" , "model_type" ]:
318
348
if v .get (kw ) is not None :
319
349
kwargs [kw ] = v [kw ]
320
350
0 commit comments