
Commit 9465083

Add documentation, fix UT.

1 parent daf85d5

4 files changed: +35 -6 lines changed


README.md

Lines changed: 17 additions & 0 deletions
@@ -139,6 +139,9 @@ python -m tf2onnx.convert
     [--outputs GRAPH_OUTPUS]
     [--inputs-as-nchw inputs_provided_as_nchw]
     [--opset OPSET]
+    [--tag TAG]
+    [--signature_def SIGNATURE_DEF]
+    [--concrete_function CONCRETE_FUNCTION]
     [--target TARGET]
     [--custom-ops list-of-custom-ops]
     [--fold_const]
@@ -176,6 +179,20 @@ By default we preserve the image format of inputs (`nchw` or `nhwc`) as given in

 By default we use the opset 8 to generate the graph. By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.

+#### --tag
+
+Only valid with parameter `--saved_model`. Specifies the tag in the saved_model to be used. Typical value is 'serve'.
+
+#### --signature_def
+
+Only valid with parameter `--saved_model`. Specifies which signature to use within the specified --tag value. Typical value is 'serving_default'.
+
+#### --concrete_function
+
+(This is experimental, valid only for TF2.x models)
+
+Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.
+
 #### --target

 Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.
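As a quick illustration of how the new saved_model options combine, a hypothetical invocation might look like `python -m tf2onnx.convert --saved_model ./my_saved_model --tag serve --signature_def serving_default --output model.onnx`, where `./my_saved_model` and `model.onnx` are placeholder paths and `--output` is the converter's usual output-file option rather than anything introduced by this commit.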

tests/run_pretrained_models.py

Lines changed: 10 additions & 4 deletions
@@ -79,13 +79,18 @@ def get_ones(shape):
     """Get ones."""
     return np.ones(shape).astype(np.float32)

+def get_zeros(shape):
+    """Get zeros."""
+    return np.zeros(shape).astype(np.float32)
+

 _INPUT_FUNC_MAPPING = {
     "get_beach": get_beach,
     "get_random": get_random,
     "get_random256": get_random256,
     "get_ramp": get_ramp,
-    "get_ones": get_ones
+    "get_ones": get_ones,
+    "get_zeros": get_zeros,
 }

 OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
@@ -100,7 +105,7 @@ class Test(object):
     def __init__(self, url, local, make_input, input_names, output_names,
                  disabled=False, rtol=0.01, atol=1e-6,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
-                 skip_tensorflow=False, opset_constraints=None, tf_min_version=None):
+                 skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None):
         self.url = url
         self.make_input = make_input
         self.local = local
@@ -114,6 +119,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
         self.tf_runtime = 0
         self.onnx_runtime = 0
         self.model_type = model_type
+        self.tag = tag
         self.force_input_shape = force_input_shape
         self.skip_tensorflow = skip_tensorflow
         self.opset_constraints = opset_constraints
@@ -240,7 +246,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         if self.model_type in ["checkpoint"]:
             graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
         elif self.model_type in ["saved_model"]:
-            graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs)
+            graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
         elif self.model_type in ["keras"]:
             graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
         else:
@@ -436,7 +442,7 @@ def load_tests_from_yaml(path):

         kwargs = {}
         for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
-                   "skip_tensorflow", "force_input_shape", "tf_min_version"]:
+                   "skip_tensorflow", "force_input_shape", "tf_min_version", "tag"]:
             if settings.get(kw) is not None:
                 kwargs[kw] = settings[kw]

tests/run_pretrained_models.yaml

Lines changed: 5 additions & 2 deletions
@@ -21,6 +21,7 @@ regression-checkpoint:
 regression-saved-model:
   model: models/regression/saved_model
   model_type: saved_model
+  tag: serve
   input_get: get_ramp
   inputs:
     "X:0": [1]
@@ -239,9 +240,10 @@ vgg-16:

 resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
   skip_tensorflow: true # tensorflow fails: Default MaxPoolingOp only supports NHWC on device type CPU
-  model_type: saved_model
   url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NCHW.tar.gz
   model: resnet_v2_fp32_savedmodel_NCHW/1538687196
+  model_type: saved_model
+  tag: serve
   input_get: get_beach
   inputs:
     "input_tensor:0": [64, 224, 224, 3]
@@ -250,9 +252,10 @@ resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
     - softmax_tensor:0

 resnet50_v2_nhwc:
-  model_type: saved_model
   url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
   model: resnet_v2_fp32_savedmodel_NHWC/1538687283
+  model_type: saved_model
+  tag: serve
   input_get: get_beach
   inputs:
     "input_tensor:0": [64, 224, 224, 3]

tf2onnx/tf_loader.py

Lines changed: 3 additions & 0 deletions
@@ -184,6 +184,9 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
     if tag is None:
         tag = [tf.saved_model.tag_constants.SERVING]

+    if not isinstance(tag, list):
+        tag = [tag]
+
     imported = tf.saved_model.loader.load(sess, tag, model_path)
     for k in imported.signature_def.keys():
         if k.startswith("_"):
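To make the effect of this hunk concrete, here is a small standalone sketch (the helper name `normalize_tag` is invented for illustration; the real change performs the same steps inline in `_from_saved_model_v1`):

```python
def normalize_tag(tag):
    """Illustrative only: reproduce how the tag argument is now prepared before loading."""
    if tag is None:
        # default serving tag; tf.saved_model.tag_constants.SERVING is the string "serve"
        tag = ["serve"]
    if not isinstance(tag, list):
        # tf.saved_model.loader.load expects a list of tags, so wrap a single string value
        tag = [tag]
    return tag

assert normalize_tag(None) == ["serve"]           # previous default behaviour
assert normalize_tag("serve") == ["serve"]        # new: a plain string from --tag or YAML now works
assert normalize_tag(["serve", "gpu"]) == ["serve", "gpu"]  # lists pass through unchanged
```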
