Skip to content

Commit 0d6a081

Browse files
committed
Merge branch 'master' into r1.6
2 parents 8d52538 + 38b1a6a commit 0d6a081

18 files changed

+724
-240
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ python -m tf2onnx.convert
139139
[--outputs GRAPH_OUTPUS]
140140
[--inputs-as-nchw inputs_provided_as_nchw]
141141
[--opset OPSET]
142+
[--tag TAG]
143+
[--signature_def SIGNATURE_DEF]
144+
[--concrete_function CONCRETE_FUNCTION]
142145
[--target TARGET]
143146
[--custom-ops list-of-custom-ops]
144147
[--fold_const]
@@ -176,6 +179,20 @@ By default we preserve the image format of inputs (`nchw` or `nhwc`) as given in
176179

177180
By default we use the opset 8 to generate the graph. By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.
178181

182+
#### --tag
183+
184+
Only valid with parameter `--saved_model`. Specifies the tag in the saved_model to be used. Typical value is 'serve'.
185+
186+
#### --signature_def
187+
188+
Only valid with parameter `--saved_model`. Specifies which signature to use within the specified --tag value. Typical value is 'serving_default'.
189+
190+
#### --concrete_function
191+
192+
(This is experimental, valid only for TF2.x models)
193+
194+
Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.
195+
179196
#### --target
180197

181198
Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.

tests/run_pretrained_models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,18 @@ def get_ones(shape):
7979
"""Get ones."""
8080
return np.ones(shape).astype(np.float32)
8181

82+
def get_zeros(shape):
83+
"""Get zeros."""
84+
return np.zeros(shape).astype(np.float32)
85+
8286

8387
_INPUT_FUNC_MAPPING = {
8488
"get_beach": get_beach,
8589
"get_random": get_random,
8690
"get_random256": get_random256,
8791
"get_ramp": get_ramp,
88-
"get_ones": get_ones
92+
"get_ones": get_ones,
93+
"get_zeros": get_zeros,
8994
}
9095

9196
OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
@@ -100,7 +105,7 @@ class Test(object):
100105
def __init__(self, url, local, make_input, input_names, output_names,
101106
disabled=False, rtol=0.01, atol=1e-6,
102107
check_only_shape=False, model_type="frozen", force_input_shape=False,
103-
skip_tensorflow=False, opset_constraints=None, tf_min_version=None):
108+
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None):
104109
self.url = url
105110
self.make_input = make_input
106111
self.local = local
@@ -114,6 +119,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
114119
self.tf_runtime = 0
115120
self.onnx_runtime = 0
116121
self.model_type = model_type
122+
self.tag = tag
117123
self.force_input_shape = force_input_shape
118124
self.skip_tensorflow = skip_tensorflow
119125
self.opset_constraints = opset_constraints
@@ -240,7 +246,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
240246
if self.model_type in ["checkpoint"]:
241247
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
242248
elif self.model_type in ["saved_model"]:
243-
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs)
249+
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
244250
elif self.model_type in ["keras"]:
245251
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
246252
else:
@@ -436,7 +442,7 @@ def load_tests_from_yaml(path):
436442

437443
kwargs = {}
438444
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
439-
"skip_tensorflow", "force_input_shape", "tf_min_version"]:
445+
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag"]:
440446
if settings.get(kw) is not None:
441447
kwargs[kw] = settings[kw]
442448

tests/run_pretrained_models.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ regression-checkpoint:
2121
regression-saved-model:
2222
model: models/regression/saved_model
2323
model_type: saved_model
24+
tag: serve
2425
input_get: get_ramp
2526
inputs:
2627
"X:0": [1]
@@ -239,9 +240,10 @@ vgg-16:
239240

240241
resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
241242
skip_tensorflow: true # tensorflow fails: Default MaxPoolingOp only supports NHWC on device type CPU
242-
model_type: saved_model
243243
url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NCHW.tar.gz
244244
model: resnet_v2_fp32_savedmodel_NCHW/1538687196
245+
model_type: saved_model
246+
tag: serve
245247
input_get: get_beach
246248
inputs:
247249
"input_tensor:0": [64, 224, 224, 3]
@@ -250,9 +252,10 @@ resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
250252
- softmax_tensor:0
251253

252254
resnet50_v2_nhwc:
253-
model_type: saved_model
254255
url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
255256
model: resnet_v2_fp32_savedmodel_NHWC/1538687283
257+
model_type: saved_model
258+
tag: serve
256259
input_get: get_beach
257260
inputs:
258261
"input_tensor:0": [64, 224, 224, 3]

tests/test_backend.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from itertools import product
1515

1616
import numpy as np
17+
from numpy.testing import assert_almost_equal
1718
import tensorflow as tf
1819

1920
from tensorflow.python.ops import lookup_ops
@@ -69,6 +70,7 @@
6970
is_inf = tf.math.is_inf
7071
floormod = tf.math.floormod
7172
matrix_diag_part = tf.compat.v1.matrix_diag_part
73+
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
7274
elif LooseVersion(tf.__version__) >= "1.13":
7375
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
7476
multinomial = tf.compat.v1.random.multinomial
@@ -88,6 +90,7 @@
8890
is_inf = tf.math.is_inf
8991
floormod = tf.floormod
9092
matrix_diag_part = tf.compat.v1.matrix_diag_part
93+
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
9194
else:
9295
conv2d_backprop_input = tf.nn.conv2d_backprop_input
9396
multinomial = tf.multinomial
@@ -3352,6 +3355,65 @@ def func(base_matrix, diag, k):
33523355

33533356
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: diag_val, _INPUT2: k_val})
33543357

3358+
@check_opset_min_version(10)
3359+
@check_tf_min_version("1.14")
3360+
def test_fakequant_with_min_max(self):
3361+
def func(x):
3362+
ret = fake_quant_with_min_max_args(
3363+
x, min=-1024, max=1023, num_bits=8, narrow_range=False, name=None)
3364+
return tf.identity(ret, name=_TFOUTPUT)
3365+
3366+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
3367+
x_val0 = np.abs(x_val)
3368+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
3369+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3370+
3371+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
3372+
x_val[0, 0] = -1024
3373+
x_val[0, 1] = -1023
3374+
x_val[0, 2] = 1024
3375+
x_val[1, 0] = 1023
3376+
x_val[1, 1] = 1025
3377+
x_val[1, 2] = -1025
3378+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3379+
3380+
@check_opset_min_version(10)
3381+
@check_tf_min_version("1.14")
3382+
def test_fakequant_with_min_max_same_sign(self):
3383+
def func_neg(x):
3384+
ret = fake_quant_with_min_max_args(
3385+
x, min=-1024*3, max=-1024, num_bits=8, narrow_range=False, name=None)
3386+
return tf.identity(ret, name=_TFOUTPUT)
3387+
3388+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024 * 3.
3389+
try:
3390+
self._run_test_case(func_neg, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3391+
except ValueError:
3392+
pass
3393+
3394+
@check_opset_min_version(9, "atan2")
3395+
def test_atan2(self):
3396+
# Test all possible pairs of pos, neg, zero for x and y.
3397+
3398+
def atan2(y, x):
3399+
sx = np.sign(x)
3400+
sy = np.sign(y)
3401+
pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-np.pi/2)
3402+
atan_part = np.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
3403+
return atan_part + pi_part
3404+
3405+
test_pairs = [[y, x] for x in [3., -4., 0.] for y in [5., -6., 0.]]
3406+
y_val = np.array([y for y, x in test_pairs], dtype=np.float32)
3407+
x_val = np.array([x for y, x in test_pairs], dtype=np.float32)
3408+
assert_almost_equal(np.arctan2(y_val, x_val), atan2(y_val, x_val))
3409+
3410+
def func(y, x):
3411+
atan2_ = tf.math.atan2(y, x)
3412+
return tf.identity(atan2_, name=_TFOUTPUT)
3413+
3414+
self._run_test_case(
3415+
func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)
3416+
33553417

33563418
if __name__ == '__main__':
33573419
unittest_main()

tests/test_convert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def test_convert_saved_model(self):
2828
self.assertTrue(run_test_case(['',
2929
'--saved-model',
3030
'tests/models/regression/saved_model',
31+
'--tag',
32+
'serve',
3133
'--output',
3234
'converted_saved_model.onnx']))
3335

0 commit comments

Comments
 (0)