 import traceback
 import zipfile
 
+import PIL.Image
 import numpy as np
 import requests
 import six
 import tensorflow as tf
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.framework.graph_util import convert_variables_to_constants
 # contrib ops are registered only when the module is imported, the following import statement is needed,
 # otherwise tf runtime error will show up when the tf model is restored from pb file because of un-registered ops.
 import tensorflow.contrib.rnn  # pylint: disable=unused-import
 import yaml
-import PIL.Image
+from tensorflow.core.framework import graph_pb2
 
 import tf2onnx
+from tf2onnx import loader
 from tf2onnx import utils
 from tf2onnx.graph import GraphUtil
 from tf2onnx.tfonnx import process_tf_graph
@@ -74,23 +74,6 @@ def get_ramp(shape):
 }
 
 
-def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
-    """Freezes the state of a session into a pruned computation graph."""
-    output_names = [i.replace(":0", "") for i in output_names]
-    graph = sess.graph
-    with graph.as_default():
-        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
-        output_names = output_names or []
-        output_names += [v.op.name for v in tf.global_variables()]
-        input_graph_def = graph.as_graph_def()
-        if clear_devices:
-            for node in input_graph_def.node:
-                node.device = ""
-        frozen_graph = convert_variables_to_constants(sess, input_graph_def,
-                                                      output_names, freeze_var_names)
-        return frozen_graph
-
-
 class Test(object):
     """Main Test class."""
 
@@ -236,45 +219,15 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
         dir_name = os.path.dirname(self.local)
         print("\tdownloaded", model_path)
 
+        inputs = list(self.input_names.keys())
+        outputs = self.output_names
         if self.model_type in ["checkpoint"]:
-            #
-            # if the input model is a checkpoint, convert it to a frozen model
-            saver = tf.train.import_meta_graph(model_path)
-            with tf.Session() as sess:
-                saver.restore(sess, model_path[:-5])
-                frozen_graph = freeze_session(sess, output_names=self.output_names)
-                tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
-            model_path = os.path.join(dir_name, "frozen.pb")
+            graph_def, inputs, outputs = loader.from_checkpoint(model_path, inputs, outputs)
         elif self.model_type in ["saved_model"]:
-            try:
-                from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
-                get_signature_def = lambda meta_graph_def, k: \
-                    signature_def_utils.get_signature_def_by_key(meta_graph_def, k)
-            except ImportError:
-                # TF1.12 changed the api
-                get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]
-
-            # saved_model format - convert to checkpoint
-            with tf.Session() as sess:
-                meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
-                inputs = {}
-                outputs = {}
-                for k in meta_graph_def.signature_def.keys():
-                    inputs_tensor_info = get_signature_def(meta_graph_def, k).inputs
-                    for _, input_tensor in sorted(inputs_tensor_info.items()):
-                        inputs[input_tensor.name] = sess.graph.get_tensor_by_name(input_tensor.name)
-                    outputs_tensor_info = get_signature_def(meta_graph_def, k).outputs
-                    for _, output_tensor in sorted(outputs_tensor_info.items()):
-                        outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
-                # freeze uses the node name derived from output:0 so only pass in output:0;
-                # it will provide all outputs of that node.
-                for o in list(outputs.keys()):
-                    if not o.endswith(":0"):
-                        del outputs[o]
-                frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
-                tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
-            model_path = os.path.join(dir_name, "frozen.pb")
-
+            graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
+        else:
+            graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)
+
         # create the input data
         inputs = {}
         for k, v in self.input_names.items():
@@ -285,10 +238,6 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
         if self.more_inputs:
             for k, v in self.more_inputs.items():
                 inputs[k] = v
-        tf.reset_default_graph()
-        graph_def = graph_pb2.GraphDef()
-        with open(model_path, "rb") as f:
-            graph_def.ParseFromString(f.read())
 
         graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
         shape_override = {}
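
For reference, a minimal sketch of how the tf2onnx.loader helpers used above replace the removed freeze_session() path. Only the calls visible in this diff (loader.from_saved_model and tf_optimize) are assumed; the model path and tensor names are hypothetical placeholders.

    from tf2onnx import loader
    from tf2onnx.tfonnx import tf_optimize

    # Hypothetical model location and tensor names, for illustration only.
    model_path = "model/saved_model"
    input_names = ["input:0"]
    output_names = ["output:0"]

    # from_saved_model() freezes the graph and returns it together with the
    # resolved input/output tensor names, mirroring run_test() above.
    graph_def, inputs, outputs = loader.from_saved_model(model_path, input_names, output_names)

    # Same tf_optimize() call as in the diff; the last argument enables constant folding.
    graph_def = tf_optimize(inputs, outputs, graph_def, True)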