@@ -178,10 +178,13 @@ def from_checkpoint(model_path, input_names, output_names):
178
178
return frozen_graph , input_names , output_names
179
179
180
180
181
- def _from_saved_model_v1 (sess , model_path , input_names , output_names , signatures ):
181
+ def _from_saved_model_v1 (sess , model_path , input_names , output_names , tag , signatures ):
182
182
"""Load tensorflow graph from saved_model."""
183
183
184
- imported = tf .saved_model .loader .load (sess , [tf .saved_model .tag_constants .SERVING ], model_path )
184
+ if tag is None :
185
+ tag = [tf .saved_model .tag_constants .SERVING ]
186
+
187
+ imported = tf .saved_model .loader .load (sess , tag , model_path )
185
188
for k in imported .signature_def .keys ():
186
189
if k .startswith ("_" ):
187
190
# consider signatures starting with '_' private
@@ -209,43 +212,67 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
209
212
return frozen_graph , input_names , output_names
210
213
211
214
212
- def _from_saved_model_v2 (model_path , input_names , output_names , signatures ):
215
+ def _from_saved_model_v2 (model_path , input_names , output_names , tag , signature_def , concrete_function_index ):
213
216
"""Load tensorflow graph from saved_model."""
214
- imported = tf .saved_model .load (model_path ) # pylint: disable=no-value-for-parameter
215
217
216
- # f = meta_graph_def.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
217
- for k in imported .signatures .keys ():
218
- if k .startswith ("_" ):
219
- # consider signatures starting with '_' private
220
- continue
221
- signatures .append (k )
222
- for k in signatures :
223
- concrete_func = imported .signatures [k ]
224
- input_names = [input_tensor .name for input_tensor in concrete_func .inputs
225
- if input_tensor .dtype != tf .dtypes .resource ]
226
- output_names = [output_tensor .name for output_tensor in concrete_func .outputs
227
- if output_tensor .dtype != tf .dtypes .resource ]
218
+ wrn_no_tag = "'--tag' not specified for saved_model. Using empty tag [[]]"
219
+ wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
220
+ err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
221
+ err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
222
+ err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
223
+ err_no_sig = "No signatures found in model. Try --concrete_function instead."
224
+ err_sig_nomatch = "Specified signature not in model %s"
225
+
226
+ if tag is None :
227
+ tag = [[]]
228
+ logger .warning (wrn_no_tag )
229
+ utils .make_sure (len (signature_def ) < 2 , err_many_sig , str (signature_def ))
230
+ imported = tf .saved_model .load (model_path , tags = tag ) # pylint: disable=no-value-for-parameter
231
+
232
+ all_sigs = imported .signatures .keys ()
233
+ valid_sigs = [s for s in all_sigs if not s .startswith ("_" )]
234
+ logger .info ("Signatures found in model: %s" , "[" + "," .join (valid_sigs ) + "]." )
235
+
236
+ concrete_func = None
237
+ if concrete_function_index is not None :
238
+ utils .make_sure (hasattr (imported , "__call__" ), err_no_call )
239
+ utils .make_sure (concrete_function_index < len (imported .__call__ .concrete_functions ),
240
+ err_index , concrete_function_index , len (imported .__call__ .concrete_functions ) - 1 )
241
+ sig = imported .__call__ .concrete_functions [concrete_function_index ].structured_input_signature [0 ][0 ]
242
+ concrete_func = imported .__call__ .get_concrete_function (sig )
243
+ elif signature_def :
244
+ utils .make_sure (signature_def [0 ] in valid_sigs , err_sig_nomatch , signature_def [0 ])
245
+ concrete_func = imported .signatures [signature_def [0 ]]
246
+ else :
247
+ utils .make_sure (len (valid_sigs ) > 0 , err_no_sig )
248
+ logger .warning (wrn_sig_1 , valid_sigs [0 ])
249
+ concrete_func = imported .signatures [valid_sigs [0 ]]
228
250
229
- frozen_graph = from_function ( concrete_func , input_names , output_names )
230
- return frozen_graph , input_names , output_names
251
+ inputs = [ tensor . name for tensor in concrete_func . inputs if tensor . dtype != tf . dtypes . resource ]
252
+ outputs = [ tensor . name for tensor in concrete_func . outputs if tensor . dtype != tf . dtypes . resource ]
231
253
254
+ # filter by user specified inputs/outputs
255
+ if input_names :
256
+ inputs = list (set (input_names ) & set (inputs ))
257
+ if output_names :
258
+ outputs = list (set (output_names ) & set (outputs ))
232
259
233
- def from_saved_model (model_path , input_names , output_names , signatures = None ):
260
+ frozen_graph = from_function (concrete_func , inputs , outputs )
261
+ return frozen_graph , inputs , outputs
262
+
263
+
264
+ def from_saved_model (model_path , input_names , output_names , tag = None , signatures = None , concrete_function = None ):
234
265
"""Load tensorflow graph from saved_model."""
235
266
if signatures is None :
236
267
signatures = []
237
268
tf_reset_default_graph ()
238
269
if is_tf2 ():
239
270
frozen_graph , input_names , output_names = \
240
- _from_saved_model_v2 (model_path , input_names , output_names , signatures )
271
+ _from_saved_model_v2 (model_path , input_names , output_names , tag , signatures , concrete_function )
241
272
else :
242
273
with tf_session () as sess :
243
274
frozen_graph , input_names , output_names = \
244
- _from_saved_model_v1 (sess , model_path , input_names , output_names , signatures )
245
-
246
- if len (signatures ) > 1 :
247
- logger .warning ("found multiple signatures %s in saved_model, pass --signature_def in command line" ,
248
- signatures )
275
+ _from_saved_model_v1 (sess , model_path , input_names , output_names , tag , signatures )
249
276
250
277
tf_reset_default_graph ()
251
278
return frozen_graph , input_names , output_names
@@ -366,6 +393,7 @@ def is_function(g):
366
393
return 'tensorflow.python.framework.func_graph.FuncGraph' in str (type (g ))
367
394
return False
368
395
396
+
369
397
_FUNCTIONS = {}
370
398
371
399
0 commit comments