Skip to content

Commit b40f91d

Browse files
committed
Add support for TFHub TF2.x saved_models, and --tag parameter
1 parent f407c31 commit b40f91d

File tree

2 files changed

+58
-28
lines changed

2 files changed

+58
-28
lines changed

tf2onnx/convert.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tf2onnx import constants, logging, utils, optimizer
2424
from tf2onnx import tf_loader
2525

26-
2726
# pylint: disable=unused-argument
2827

2928
_HELP_TEXT = """
@@ -48,7 +47,10 @@ def get_args():
4847
parser.add_argument("--input", help="input from graphdef")
4948
parser.add_argument("--graphdef", help="input from graphdef")
5049
parser.add_argument("--saved-model", help="input from saved model")
51-
parser.add_argument("--signature_def", help="signature_def from saved model to use")
50+
parser.add_argument("--tag", help="tag to use for saved_model")
51+
parser.add_argument("--signature_def", help="signature_def from saved_model to use")
52+
parser.add_argument("--concrete_function", type=int, default=None,
53+
help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
5254
parser.add_argument("--checkpoint", help="input from checkpoint")
5355
parser.add_argument("--keras", help="input from keras model")
5456
parser.add_argument("--output", help="output model file")
@@ -127,7 +129,7 @@ def main():
127129
model_path = args.checkpoint
128130
if args.saved_model:
129131
graph_def, inputs, outputs = tf_loader.from_saved_model(
130-
args.saved_model, args.inputs, args.outputs, args.signature_def)
132+
args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function)
131133
model_path = args.saved_model
132134
if args.keras:
133135
graph_def, inputs, outputs = tf_loader.from_keras(

tf2onnx/tf_loader.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,13 @@ def from_checkpoint(model_path, input_names, output_names):
178178
return frozen_graph, input_names, output_names
179179

180180

181-
def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures):
181+
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures):
182182
"""Load tensorflow graph from saved_model."""
183183

184-
imported = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
184+
if tag is None:
185+
tag = [tf.saved_model.tag_constants.SERVING]
186+
187+
imported = tf.saved_model.loader.load(sess, tag, model_path)
185188
for k in imported.signature_def.keys():
186189
if k.startswith("_"):
187190
# consider signatures starting with '_' private
@@ -209,43 +212,67 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
209212
return frozen_graph, input_names, output_names
210213

211214

212-
def _from_saved_model_v2(model_path, input_names, output_names, signatures):
215+
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, concrete_function_index):
213216
"""Load tensorflow graph from saved_model."""
214-
imported = tf.saved_model.load(model_path) # pylint: disable=no-value-for-parameter
215217

216-
# f = meta_graph_def.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
217-
for k in imported.signatures.keys():
218-
if k.startswith("_"):
219-
# consider signatures starting with '_' private
220-
continue
221-
signatures.append(k)
222-
for k in signatures:
223-
concrete_func = imported.signatures[k]
224-
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
225-
if input_tensor.dtype != tf.dtypes.resource]
226-
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
227-
if output_tensor.dtype != tf.dtypes.resource]
218+
wrn_no_tag = "'--tag' not specified for saved_model. Using empty tag [[]]"
219+
wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
220+
err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
221+
err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
222+
err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
223+
err_no_sig = "No signatures found in model. Try --concrete_function instead."
224+
err_sig_nomatch = "Specified signature not in model %s"
225+
226+
if tag is None:
227+
tag = [[]]
228+
logger.warning(wrn_no_tag)
229+
utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
230+
imported = tf.saved_model.load(model_path, tags=tag) # pylint: disable=no-value-for-parameter
231+
232+
all_sigs = imported.signatures.keys()
233+
valid_sigs = [s for s in all_sigs if not s.startswith("_")]
234+
logger.info("Signatures found in model: %s", "[" + ",".join(valid_sigs) + "].")
235+
236+
concrete_func = None
237+
if concrete_function_index is not None:
238+
utils.make_sure(hasattr(imported, "__call__"), err_no_call)
239+
utils.make_sure(concrete_function_index < len(imported.__call__.concrete_functions),
240+
err_index, concrete_function_index, len(imported.__call__.concrete_functions) - 1)
241+
sig = imported.__call__.concrete_functions[concrete_function_index].structured_input_signature[0][0]
242+
concrete_func = imported.__call__.get_concrete_function(sig)
243+
elif signature_def:
244+
utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch, signature_def[0])
245+
concrete_func = imported.signatures[signature_def[0]]
246+
else:
247+
utils.make_sure(len(valid_sigs) > 0, err_no_sig)
248+
logger.warning(wrn_sig_1, valid_sigs[0])
249+
concrete_func = imported.signatures[valid_sigs[0]]
228250

229-
frozen_graph = from_function(concrete_func, input_names, output_names)
230-
return frozen_graph, input_names, output_names
251+
inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource]
252+
outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]
231253

254+
# filter by user specified inputs/outputs
255+
if input_names:
256+
inputs = list(set(input_names) & set(inputs))
257+
if output_names:
258+
outputs = list(set(output_names) & set(outputs))
232259

233-
def from_saved_model(model_path, input_names, output_names, signatures=None):
260+
frozen_graph = from_function(concrete_func, inputs, outputs)
261+
return frozen_graph, inputs, outputs
262+
263+
264+
def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None):
234265
"""Load tensorflow graph from saved_model."""
235266
if signatures is None:
236267
signatures = []
237268
tf_reset_default_graph()
238269
if is_tf2():
239270
frozen_graph, input_names, output_names = \
240-
_from_saved_model_v2(model_path, input_names, output_names, signatures)
271+
_from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function)
241272
else:
242273
with tf_session() as sess:
243274
frozen_graph, input_names, output_names = \
244-
_from_saved_model_v1(sess, model_path, input_names, output_names, signatures)
245-
246-
if len(signatures) > 1:
247-
logger.warning("found multiple signatures %s in saved_model, pass --signature_def in command line",
248-
signatures)
275+
_from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures)
249276

250277
tf_reset_default_graph()
251278
return frozen_graph, input_names, output_names
@@ -366,6 +393,7 @@ def is_function(g):
366393
return 'tensorflow.python.framework.func_graph.FuncGraph' in str(type(g))
367394
return False
368395

396+
369397
_FUNCTIONS = {}
370398

371399

0 commit comments

Comments
 (0)