Skip to content

Commit bba23d2

Browse files
authored
Merge pull request #984 from jignparm/jignparm/savedmodelv2
Add support for TF2.x saved_models from TFHub, as well as --tag & -concrete_function cmd line parameters
2 parents f407c31 + 9465083 commit bba23d2

File tree

6 files changed

+95
-34
lines changed

6 files changed

+95
-34
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ python -m tf2onnx.convert
139139
[--outputs GRAPH_OUTPUS]
140140
[--inputs-as-nchw inputs_provided_as_nchw]
141141
[--opset OPSET]
142+
[--tag TAG]
143+
[--signature_def SIGNATURE_DEF]
144+
[--concrete_function CONCRETE_FUNCTION]
142145
[--target TARGET]
143146
[--custom-ops list-of-custom-ops]
144147
[--fold_const]
@@ -176,6 +179,20 @@ By default we preserve the image format of inputs (`nchw` or `nhwc`) as given in
176179

177180
By default we use the opset 8 to generate the graph. By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.
178181

182+
#### --tag
183+
184+
Only valid with parameter `--saved_model`. Specifies the tag in the saved_model to be used. Typical value is 'serve'.
185+
186+
#### --signature_def
187+
188+
Only valid with parameter `--saved_model`. Specifies which signature to use within the specified --tag value. Typical value is 'serving_default'.
189+
190+
#### --concrete_function
191+
192+
(This is experimental, valid only for TF2.x models)
193+
194+
Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.
195+
179196
#### --target
180197

181198
Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.

tests/run_pretrained_models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,18 @@ def get_ones(shape):
7979
"""Get ones."""
8080
return np.ones(shape).astype(np.float32)
8181

82+
def get_zeros(shape):
83+
"""Get zeros."""
84+
return np.zeros(shape).astype(np.float32)
85+
8286

8387
_INPUT_FUNC_MAPPING = {
8488
"get_beach": get_beach,
8589
"get_random": get_random,
8690
"get_random256": get_random256,
8791
"get_ramp": get_ramp,
88-
"get_ones": get_ones
92+
"get_ones": get_ones,
93+
"get_zeros": get_zeros,
8994
}
9095

9196
OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
@@ -100,7 +105,7 @@ class Test(object):
100105
def __init__(self, url, local, make_input, input_names, output_names,
101106
disabled=False, rtol=0.01, atol=1e-6,
102107
check_only_shape=False, model_type="frozen", force_input_shape=False,
103-
skip_tensorflow=False, opset_constraints=None, tf_min_version=None):
108+
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None):
104109
self.url = url
105110
self.make_input = make_input
106111
self.local = local
@@ -114,6 +119,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
114119
self.tf_runtime = 0
115120
self.onnx_runtime = 0
116121
self.model_type = model_type
122+
self.tag = tag
117123
self.force_input_shape = force_input_shape
118124
self.skip_tensorflow = skip_tensorflow
119125
self.opset_constraints = opset_constraints
@@ -240,7 +246,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
240246
if self.model_type in ["checkpoint"]:
241247
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
242248
elif self.model_type in ["saved_model"]:
243-
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs)
249+
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
244250
elif self.model_type in ["keras"]:
245251
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
246252
else:
@@ -436,7 +442,7 @@ def load_tests_from_yaml(path):
436442

437443
kwargs = {}
438444
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
439-
"skip_tensorflow", "force_input_shape", "tf_min_version"]:
445+
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag"]:
440446
if settings.get(kw) is not None:
441447
kwargs[kw] = settings[kw]
442448

tests/run_pretrained_models.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ regression-checkpoint:
2121
regression-saved-model:
2222
model: models/regression/saved_model
2323
model_type: saved_model
24+
tag: serve
2425
input_get: get_ramp
2526
inputs:
2627
"X:0": [1]
@@ -239,9 +240,10 @@ vgg-16:
239240

240241
resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
241242
skip_tensorflow: true # tensorflow fails: Default MaxPoolingOp only supports NHWC on device type CPU
242-
model_type: saved_model
243243
url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NCHW.tar.gz
244244
model: resnet_v2_fp32_savedmodel_NCHW/1538687196
245+
model_type: saved_model
246+
tag: serve
245247
input_get: get_beach
246248
inputs:
247249
"input_tensor:0": [64, 224, 224, 3]
@@ -250,9 +252,10 @@ resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
250252
- softmax_tensor:0
251253

252254
resnet50_v2_nhwc:
253-
model_type: saved_model
254255
url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
255256
model: resnet_v2_fp32_savedmodel_NHWC/1538687283
257+
model_type: saved_model
258+
tag: serve
256259
input_get: get_beach
257260
inputs:
258261
"input_tensor:0": [64, 224, 224, 3]

tests/test_convert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def test_convert_saved_model(self):
2828
self.assertTrue(run_test_case(['',
2929
'--saved-model',
3030
'tests/models/regression/saved_model',
31+
'--tag',
32+
'serve',
3133
'--output',
3234
'converted_saved_model.onnx']))
3335

tf2onnx/convert.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from tf2onnx import constants, logging, utils, optimizer
2424
from tf2onnx import tf_loader
2525

26-
2726
# pylint: disable=unused-argument
2827

2928
_HELP_TEXT = """
@@ -48,7 +47,10 @@ def get_args():
4847
parser.add_argument("--input", help="input from graphdef")
4948
parser.add_argument("--graphdef", help="input from graphdef")
5049
parser.add_argument("--saved-model", help="input from saved model")
51-
parser.add_argument("--signature_def", help="signature_def from saved model to use")
50+
parser.add_argument("--tag", help="tag to use for saved_model")
51+
parser.add_argument("--signature_def", help="signature_def from saved_model to use")
52+
parser.add_argument("--concrete_function", type=int, default=None,
53+
help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
5254
parser.add_argument("--checkpoint", help="input from checkpoint")
5355
parser.add_argument("--keras", help="input from keras model")
5456
parser.add_argument("--output", help="output model file")
@@ -127,7 +129,7 @@ def main():
127129
model_path = args.checkpoint
128130
if args.saved_model:
129131
graph_def, inputs, outputs = tf_loader.from_saved_model(
130-
args.saved_model, args.inputs, args.outputs, args.signature_def)
132+
args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function)
131133
model_path = args.saved_model
132134
if args.keras:
133135
graph_def, inputs, outputs = tf_loader.from_keras(

tf2onnx/tf_loader.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,16 @@ def from_checkpoint(model_path, input_names, output_names):
178178
return frozen_graph, input_names, output_names
179179

180180

181-
def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures):
181+
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures):
182182
"""Load tensorflow graph from saved_model."""
183183

184-
imported = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
184+
if tag is None:
185+
tag = [tf.saved_model.tag_constants.SERVING]
186+
187+
if not isinstance(tag, list):
188+
tag = [tag]
189+
190+
imported = tf.saved_model.loader.load(sess, tag, model_path)
185191
for k in imported.signature_def.keys():
186192
if k.startswith("_"):
187193
# consider signatures starting with '_' private
@@ -209,43 +215,67 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
209215
return frozen_graph, input_names, output_names
210216

211217

212-
def _from_saved_model_v2(model_path, input_names, output_names, signatures):
218+
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, concrete_function_index):
213219
"""Load tensorflow graph from saved_model."""
214-
imported = tf.saved_model.load(model_path) # pylint: disable=no-value-for-parameter
215220

216-
# f = meta_graph_def.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
217-
for k in imported.signatures.keys():
218-
if k.startswith("_"):
219-
# consider signatures starting with '_' private
220-
continue
221-
signatures.append(k)
222-
for k in signatures:
223-
concrete_func = imported.signatures[k]
224-
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
225-
if input_tensor.dtype != tf.dtypes.resource]
226-
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
227-
if output_tensor.dtype != tf.dtypes.resource]
221+
wrn_no_tag = "'--tag' not specified for saved_model. Using empty tag [[]]"
222+
wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
223+
err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
224+
err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
225+
err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
226+
err_no_sig = "No signatures found in model. Try --concrete_function instead."
227+
err_sig_nomatch = "Specified signature not in model %s"
228+
229+
if tag is None:
230+
tag = [[]]
231+
logger.warning(wrn_no_tag)
232+
utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
233+
imported = tf.saved_model.load(model_path, tags=tag) # pylint: disable=no-value-for-parameter
234+
235+
all_sigs = imported.signatures.keys()
236+
valid_sigs = [s for s in all_sigs if not s.startswith("_")]
237+
logger.info("Signatures found in model: %s", "[" + ",".join(valid_sigs) + "].")
238+
239+
concrete_func = None
240+
if concrete_function_index is not None:
241+
utils.make_sure(hasattr(imported, "__call__"), err_no_call)
242+
utils.make_sure(concrete_function_index < len(imported.__call__.concrete_functions),
243+
err_index, concrete_function_index, len(imported.__call__.concrete_functions) - 1)
244+
sig = imported.__call__.concrete_functions[concrete_function_index].structured_input_signature[0][0]
245+
concrete_func = imported.__call__.get_concrete_function(sig)
246+
elif signature_def:
247+
utils.make_sure(signature_def[0] in valid_sigs, err_sig_nomatch, signature_def[0])
248+
concrete_func = imported.signatures[signature_def[0]]
249+
else:
250+
utils.make_sure(len(valid_sigs) > 0, err_no_sig)
251+
logger.warning(wrn_sig_1, valid_sigs[0])
252+
concrete_func = imported.signatures[valid_sigs[0]]
228253

229-
frozen_graph = from_function(concrete_func, input_names, output_names)
230-
return frozen_graph, input_names, output_names
254+
inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource]
255+
outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]
231256

257+
# filter by user specified inputs/outputs
258+
if input_names:
259+
inputs = list(set(input_names) & set(inputs))
260+
if output_names:
261+
outputs = list(set(output_names) & set(outputs))
232262

233-
def from_saved_model(model_path, input_names, output_names, signatures=None):
263+
frozen_graph = from_function(concrete_func, inputs, outputs)
264+
return frozen_graph, inputs, outputs
265+
266+
267+
def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None):
234268
"""Load tensorflow graph from saved_model."""
235269
if signatures is None:
236270
signatures = []
237271
tf_reset_default_graph()
238272
if is_tf2():
239273
frozen_graph, input_names, output_names = \
240-
_from_saved_model_v2(model_path, input_names, output_names, signatures)
274+
_from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function)
241275
else:
242276
with tf_session() as sess:
243277
frozen_graph, input_names, output_names = \
244-
_from_saved_model_v1(sess, model_path, input_names, output_names, signatures)
245-
246-
if len(signatures) > 1:
247-
logger.warning("found multiple signatures %s in saved_model, pass --signature_def in command line",
248-
signatures)
278+
_from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures)
249279

250280
tf_reset_default_graph()
251281
return frozen_graph, input_names, output_names
@@ -366,6 +396,7 @@ def is_function(g):
366396
return 'tensorflow.python.framework.func_graph.FuncGraph' in str(type(g))
367397
return False
368398

399+
369400
_FUNCTIONS = {}
370401

371402

0 commit comments

Comments
 (0)