@@ -180,9 +180,6 @@ def from_checkpoint(model_path, input_names, output_names):
180
180
181
181
def _from_saved_model_v1 (sess , model_path , input_names , output_names , signatures ):
182
182
"""Load tensorflow graph from saved_model."""
183
- # make sure we start with clean default graph
184
- inputs = {}
185
- outputs = {}
186
183
187
184
imported = tf .saved_model .loader .load (sess , [tf .saved_model .tag_constants .SERVING ], model_path )
188
185
for k in imported .signature_def .keys ():
@@ -199,17 +196,17 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
199
196
# TF1.12 changed the api
200
197
get_signature_def = lambda meta_graph_def , k : meta_graph_def .signature_def [k ]
201
198
199
+ input_names = []
200
+ output_names = []
202
201
for k in signatures :
203
202
inputs_tensor_info = get_signature_def (imported , k ).inputs
204
203
for _ , input_tensor in inputs_tensor_info .items ():
205
- inputs [input_tensor .name ] = sess .graph .get_tensor_by_name (input_tensor .name )
204
+ if input_tensor .dtype != tf .dtypes .resource :
205
+ input_names .append (input_tensor .name )
206
206
outputs_tensor_info = get_signature_def (imported , k ).outputs
207
207
for _ , output_tensor in outputs_tensor_info .items ():
208
- outputs [output_tensor .name ] = sess .graph .get_tensor_by_name (output_tensor .name )
209
-
210
- input_names = list (inputs .keys ())
211
- input_names = inputs_without_resource (sess , input_names )
212
- output_names = list (outputs .keys ())
208
+ if output_tensor .dtype != tf .dtypes .resource :
209
+ output_names .append (output_tensor .name )
213
210
frozen_graph = freeze_session (sess , input_names = input_names , output_names = output_names )
214
211
return frozen_graph , input_names , output_names
215
212
@@ -226,8 +223,10 @@ def _from_saved_model_v2(model_path, input_names, output_names, signatures):
226
223
signatures .append (k )
227
224
for k in signatures :
228
225
concrete_func = imported .signatures [k ]
229
- input_names = [input_tensor .name for input_tensor in concrete_func .inputs ]
230
- output_names = [output_tensor .name for output_tensor in concrete_func .outputs ]
226
+ input_names = [input_tensor .name for input_tensor in concrete_func .inputs
227
+ if input_tensor .dtype != tf .dtypes .resource ]
228
+ output_names = [output_tensor .name for output_tensor in concrete_func .outputs
229
+ if output_tensor .dtype != tf .dtypes .resource ]
231
230
232
231
frozen_graph = from_function (concrete_func , input_names , output_names )
233
232
return frozen_graph , input_names , output_names
@@ -270,25 +269,21 @@ def from_keras(model_path, input_names, output_names):
270
269
concrete_func = function .get_concrete_function ()
271
270
# allow to pass inputs and outputs from caller if we don't want all of them
272
271
input_names = [input_tensor .name for input_tensor in concrete_func .inputs
273
- if input_names is None or input_tensor .name in input_names ]
272
+ if input_tensor .dtype != tf . dtypes . resource ]
274
273
output_names = [output_tensor .name for output_tensor in concrete_func .outputs
275
- if output_names is None or output_tensor .name in output_names ]
274
+ if output_tensor .dtype != tf .dtypes .resource ]
275
+
276
276
frozen_graph = from_function (concrete_func , input_names , output_names )
277
277
else :
278
278
# Handles Keras when Eager mode is disabled.
279
279
_keras .backend .clear_session ()
280
280
_keras .backend .set_learning_phase (False )
281
281
keras_model = _keras .models .load_model (model_path , custom_objects )
282
282
# allow to pass inputs and outputs from caller if we don't want all of them
283
- if input_names :
284
- input_names = [i for i in keras_model .inputs if i in input_names ]
285
- else :
286
- input_names = keras_model .inputs
287
- if output_names :
288
- output_names = [i for i in keras_model .outputs if i in output_names ]
289
- else :
290
- output_names = keras_model .outputs
283
+ input_names = keras_model .inputs
284
+ output_names = keras_model .outputs
291
285
sess = _keras .backend .get_session ()
286
+ input_names = inputs_without_resource (sess , input_names )
292
287
frozen_graph = freeze_session (sess , input_names = input_names , output_names = output_names )
293
288
tf_reset_default_graph ()
294
289
with tf_session () as sess :
0 commit comments