@@ -83,16 +83,28 @@ def not_implemented_tf_placeholder(*args, **kwargs):
     extract_sub_graph = tf.graph_util.extract_sub_graph


+def inputs_without_resource(sess, input_names):
+    try:
+        new_input_names = []
+        for n in input_names:
+            t = sess.graph.get_tensor_by_name(n)
+            if t.dtype != tf.dtypes.resource:
+                new_input_names.append(n)
+        input_names = new_input_names
+    except:  # pylint: disable=bare-except
+        pass
+    return input_names
+
+
 def from_function(func, input_names, output_names):
     frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
     graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
-    # output_tensors = { i.name: i for i in frozen_func.outputs}
+    # output_names = [i.name for i in frozen_func.outputs]
     tf_reset_default_graph()
     with tf_session() as sess:
         tf.import_graph_def(graph_def, name='')
-        input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
-        output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
-        graph_def = tf_optimize(input_tensors, output_tensors, graph_def)
+        input_names = inputs_without_resource(sess, input_names)
+        graph_def = tf_optimize(input_names, output_names, graph_def)
     return graph_def


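Note: the new `inputs_without_resource` helper simply drops any graph input whose tensor dtype is `resource` (variable handles cannot be fed). A minimal standalone sketch of the same idea, assuming you already have a `graph_def` and a list of candidate input names (both are placeholders, not part of this diff):

import tensorflow as tf

def filter_resource_inputs(graph_def, input_names):
    # Illustrative helper only: keep names whose tensors are not resource handles.
    with tf.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def, name='')
        return [n for n in input_names
                if graph.get_tensor_by_name(n).dtype != tf.dtypes.resource]
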
@@ -101,7 +113,6 @@ def freeze_session(sess, input_names=None, output_names=None):
     output_node_names = [i.split(':')[:-1][0] for i in output_names]
     keep_var_names = [i.split(':')[:-1][0] for i in input_names]
     with sess.graph.as_default():
-        # freeze_var_names = list(set(v.op.name for v in tf_global_variables()).difference(keep_var_names or []))
         output_node_names = output_node_names or []
         output_node_names += [v.op.name for v in tf_global_variables()]
         output_node_names += keep_var_names
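
Aside: `freeze_session` works on tensor names of the form "node_name:port", and the `i.split(':')[:-1][0]` idiom just strips the output port to recover the node name. A tiny illustration with made-up names:

tensor_names = ["input:0", "model/dense/BiasAdd:0"]
node_names = [name.split(':')[:-1][0] for name in tensor_names]
assert node_names == ["input", "model/dense/BiasAdd"]
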
@@ -135,16 +146,16 @@ def from_graphdef(model_path, input_names, output_names):
         with tf_gfile.GFile(model_path, 'rb') as f:
             graph_def.ParseFromString(f.read())
             tf.import_graph_def(graph_def, name='')
+            input_names = inputs_without_resource(sess, input_names)
             frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
         input_names = remove_redundant_inputs(frozen_graph, input_names)
-        inputs = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
-        outputs = {i: sess.graph.get_tensor_by_name(i) for i in output_names}

     tf_reset_default_graph()
     with tf_session() as sess:
-        frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
+        input_names = inputs_without_resource(sess, input_names)
+        frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
     tf_reset_default_graph()
-    return frozen_graph, inputs, outputs
+    return frozen_graph, input_names, output_names


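With this change the loaders return plain name lists instead of `{name: tensor}` dicts. A hedged usage sketch for `from_graphdef`; the file name and tensor names are placeholders, not taken from this diff:

graph_def, input_names, output_names = from_graphdef(
    "model.pb", input_names=["input:0"], output_names=["output:0"])
# Callers now receive lists of tensor names rather than dicts of live tensors.
print(input_names)   # e.g. ['input:0']
print(output_names)  # e.g. ['output:0']
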
 def from_checkpoint(model_path, input_names, output_names):
@@ -156,23 +167,19 @@ def from_checkpoint(model_path, input_names, output_names):
         saver = tf_import_meta_graph(model_path, clear_devices=True)
         # restore from model_path minus the ".meta"
         saver.restore(sess, model_path[:-5])
+        input_names = inputs_without_resource(sess, input_names)
         frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
         input_names = remove_redundant_inputs(frozen_graph, input_names)
-        inputs = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
-        outputs = {i: sess.graph.get_tensor_by_name(i) for i in output_names}

     tf_reset_default_graph()
     with tf_session() as sess:
-        frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
+        frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
     tf_reset_default_graph()
-    return frozen_graph, inputs, outputs
+    return frozen_graph, input_names, output_names


 def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures):
     """Load tensorflow graph from saved_model."""
-    # make sure we start with clean default graph
-    inputs = {}
-    outputs = {}

     imported = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
     for k in imported.signature_def.keys():
@@ -189,24 +196,21 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
         # TF1.12 changed the api
         get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]

+    input_names = []
+    output_names = []
     for k in signatures:
         inputs_tensor_info = get_signature_def(imported, k).inputs
         for _, input_tensor in inputs_tensor_info.items():
-            inputs[input_tensor.name] = sess.graph.get_tensor_by_name(input_tensor.name)
+            input_names.append(input_tensor.name)
         outputs_tensor_info = get_signature_def(imported, k).outputs
         for _, output_tensor in outputs_tensor_info.items():
-            outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
-
-    frozen_graph = freeze_session(sess, input_names=list(inputs.keys()), output_names=list(outputs.keys()))
-    return frozen_graph, inputs, outputs
+            output_names.append(output_tensor.name)
+    frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
+    return frozen_graph, input_names, output_names


 def _from_saved_model_v2(model_path, input_names, output_names, signatures):
     """Load tensorflow graph from saved_model."""
-    # make sure we start with clean default graph
-    inputs = {}
-    outputs = {}
-
     imported = tf.saved_model.load(model_path)  # pylint: disable=no-value-for-parameter

     # f = meta_graph_def.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@@ -217,34 +221,34 @@ def _from_saved_model_v2(model_path, input_names, output_names, signatures):
             signatures.append(k)
     for k in signatures:
         concrete_func = imported.signatures[k]
-        inputs = {input_tensor.name: input_tensor for input_tensor in concrete_func.inputs}
-        outputs = {output_tensor.name: output_tensor for output_tensor in concrete_func.outputs}
+        input_names = [input_tensor.name for input_tensor in concrete_func.inputs
+                       if input_tensor.dtype != tf.dtypes.resource]
+        output_names = [output_tensor.name for output_tensor in concrete_func.outputs
+                        if output_tensor.dtype != tf.dtypes.resource]

-    frozen_graph = from_function(concrete_func, list(inputs.keys()), list(outputs.keys()))
-    return frozen_graph, inputs, outputs
+    frozen_graph = from_function(concrete_func, input_names, output_names)
+    return frozen_graph, input_names, output_names


 def from_saved_model(model_path, input_names, output_names, signatures=None):
     """Load tensorflow graph from saved_model."""
-    # make sure we start with clean default graph
     if signatures is None:
         signatures = []
     tf_reset_default_graph()
     if is_tf2():
-        frozen_graph, inputs, outputs = \
+        frozen_graph, input_names, output_names = \
             _from_saved_model_v2(model_path, input_names, output_names, signatures)
-        inputs = {k: v for k, v in inputs.items() if v.dtype != tf.dtypes.resource}
     else:
         with tf_session() as sess:
-            frozen_graph, inputs, outputs = \
+            frozen_graph, input_names, output_names = \
                 _from_saved_model_v1(sess, model_path, input_names, output_names, signatures)

     if len(signatures) > 1:
         logger.warning("found multiple signatures %s in saved_model, pass --signature_def in command line",
                        signatures)

     tf_reset_default_graph()
-    return frozen_graph, inputs, outputs
+    return frozen_graph, input_names, output_names


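The saved_model path follows the same pattern: resource-typed tensors are now filtered inside `_from_saved_model_v2` itself, so the resource-stripping dict comprehension in `from_saved_model` can go away. A hedged usage sketch (directory and signature name are placeholders):

graph_def, input_names, output_names = from_saved_model(
    "./saved_model_dir", input_names=None, output_names=None,
    signatures=["serving_default"])
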
 def from_keras(model_path, input_names, output_names):
@@ -261,34 +265,35 @@ def from_keras(model_path, input_names, output_names):

         function = _saving_utils.trace_model_call(keras_model)
         concrete_func = function.get_concrete_function()
-        inputs = {input_tensor.name: input_tensor for input_tensor in concrete_func.inputs if
-                  input_tensor.name in input_names}
-        outputs = {output_tensor.name: output_tensor for output_tensor in concrete_func.outputs if
-                   output_tensor.name in output_names}
-        frozen_graph = from_function(concrete_func, list(inputs.keys()), list(outputs.keys()))
+        # allow to pass inputs and outputs from caller if we don't want all of them
+        input_names = [input_tensor.name for input_tensor in concrete_func.inputs
+                       if input_tensor.dtype != tf.dtypes.resource]
+        output_names = [output_tensor.name for output_tensor in concrete_func.outputs
+                        if output_tensor.dtype != tf.dtypes.resource]
+
+        frozen_graph = from_function(concrete_func, input_names, output_names)
     else:
         # Handles Keras when Eager mode is disabled.
         _keras.backend.clear_session()
         _keras.backend.set_learning_phase(False)
         keras_model = _keras.models.load_model(model_path, custom_objects)
+        # allow to pass inputs and outputs from caller if we don't want all of them
+        input_names = keras_model.inputs
+        output_names = keras_model.outputs
         sess = _keras.backend.get_session()
+        input_names = inputs_without_resource(sess, input_names)
         frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
-        inputs = {i: sess.graph.get_tensor_by_name(i) for i in keras_model.inputs}
-        outputs = {i: sess.graph.get_tensor_by_name(i) for i in keras_model.outputs}
     tf_reset_default_graph()
     with tf_session() as sess:
-        frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
+        frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
     tf_reset_default_graph()
-    return frozen_graph, inputs, outputs
+    return frozen_graph, input_names, output_names


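`from_keras` now returns names as well: in the eager branch they come from the concrete function (minus resource tensors), in the graph-mode branch from `keras_model.inputs`/`outputs`. A hedged usage sketch; the .h5 path is a placeholder:

graph_def, input_names, output_names = from_keras(
    "model.h5", input_names=None, output_names=None)
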
-def tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant=None):
+def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
     from tensorflow.core.protobuf import meta_graph_pb2 as meta_graph_pb2, config_pb2, rewriter_config_pb2
     from tensorflow.python.grappler import tf_optimizer as tf_opt

-    # don't use resource type as input
-    output_tensors = list(output_tensors)
-
     config = config_pb2.ConfigProto()
     rewrite_options = config.graph_options.rewrite_options
     config.graph_options.infer_shapes = True
@@ -300,34 +305,27 @@ def tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant
     ]
     meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
     fetch_collection = meta_graph_pb2.CollectionDef()
-    for t in list(input_tensors) + output_tensors:
+    for t in input_names + output_names:
         fetch_collection.node_list.value.append(t)
     meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
     graph_def = tf_opt.OptimizeGraph(config, meta_graph)
     return graph_def


-def tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=True):
+def tf_optimize(input_names, output_names, graph_def, fold_constant=True):
     """Extract inference subgraph and optimize graph."""
-    assert isinstance(input_tensors, dict)
-    assert isinstance(output_tensors, dict)
-    try:
-        input_tensors = {
-            name: tensor for name, tensor in input_tensors.items()
-            if tensor.dtype != tf.dtypes.resource
-        }
-    except:  # pylint: disable=bare-except
-        pass
+    assert isinstance(input_names, list)
+    assert isinstance(output_names, list)

     # TODO: is this needed ?
-    needed_names = [utils.node_name(i) for i in input_tensors.keys()] + \
-                   [utils.node_name(i) for i in output_tensors.keys()]
+    needed_names = [utils.node_name(i) for i in input_names] + \
+                   [utils.node_name(i) for i in output_names]
     graph_def = extract_sub_graph(graph_def, needed_names)

     if fold_constant:
         want_grappler = is_tf2() or LooseVersion(tf.__version__) >= "1.15"
         if want_grappler:
-            graph_def = tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant)
+            graph_def = tf_optimize_grappler(input_names, output_names, graph_def, fold_constant)
         else:
             # the older transform path
             from tensorflow.tools.graph_transforms import TransformGraph  # pylint: disable=redefined-outer-name
@@ -341,7 +339,7 @@ def tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=True):
                 "fold_batch_norms",
                 "fold_old_batch_norms",
             ])
-            graph_def = TransformGraph(graph_def, input_tensors.keys(), output_tensors.keys(), transforms)
+            graph_def = TransformGraph(graph_def, input_names, output_names, transforms)

     return graph_def

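`tf_optimize` now takes the two name lists directly and asserts on them, instead of accepting tensor dicts and filtering out resource types itself (that filtering has moved into the loaders via `inputs_without_resource`). A hedged sketch of the new call shape, with placeholder names:

# frozen_graph_def is assumed to be a frozen tf.compat.v1.GraphDef
optimized = tf_optimize(["input:0"], ["output:0"], frozen_graph_def, fold_constant=True)
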
@@ -367,7 +365,6 @@ def is_function(g):
         return 'tensorflow.python.framework.func_graph.FuncGraph' in str(type(g))
     return False

-
 _FUNCTIONS = {}


@@ -387,7 +384,12 @@ def toposort(data):
         fdef = fdef.definition
         if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
             input_shapes = input_shapes[:len(fdef.signature.input_arg)]
-        func = function_def_to_graph(fdef, input_shapes=input_shapes)
+        try:
+            func = function_def_to_graph(fdef, input_shapes=input_shapes)
+        except:  # pylint: disable=bare-except
+            # if there is a mismatch between caller and function, use the function's shapes
+            logger.warning("shape mismatch between caller and function: %s", k)
+            func = function_def_to_graph(fdef)
         _FUNCTIONS[k] = func
         _, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
         functions.update(tfunctions)
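
The new try/except is a plain fallback: first try `function_def_to_graph` with the caller's shapes, and if that raises, rebuild the function graph from its own signature. A generic sketch of the same retry-without-hint pattern, with hypothetical names:

def convert_with_fallback(fdef, input_shapes, convert, log):
    # convert is assumed to behave like function_def_to_graph(fdef, input_shapes=...)
    try:
        return convert(fdef, input_shapes=input_shapes)
    except Exception:
        log.warning("shape mismatch between caller and function, retrying without shapes")
        return convert(fdef)
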