Skip to content

Commit 7c5340c

Browse files
authored
Merge pull request #2 from onnx/master
update
2 parents ea46335 + e59907b commit 7c5340c

27 files changed

+603
-158
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
tf2onnx - Convert TensorFlow models to ONNX.
22
========
33

4-
[![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build?definitionId=16&branchName=master)
4+
| Build Type | OS | Python | Tensorflow | Onnx opset | Status |
5+
| --- | --- | --- | --- | --- | --- |
6+
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.5, 3.6 | 1.5-1.13 | 7-10 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
7+
| Unit Test - Full | Linux, MacOS, Windows | 3.5, 3.6, 3.7 | 1.5-1.13 | 7-10 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | |
8+
9+
<a name="build_status_footnote">\*</a> Only test on python3.6, TF1.13.
510

611
# Supported ONNX version
712
tensorflow-onnx will use the ONNX version installed on your system and installs the latest ONNX version if none is found.

ci_build/azure_pipelines/onnxruntime_nightly_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
jobs:
44
- template: 'templates/job_generator.yml'
55
parameters:
6-
tf_versions: ['1.13.1']
6+
tf_versions: ['1.14']
77
onnx_opsets: ['']
88
onnx_backends:
99
onnxruntime: ['']

ci_build/azure_pipelines/pretrained_model_test-matrix.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ jobs:
55
parameters:
66
platforms: ['linux', 'windows', 'mac']
77
python_versions: ['3.6', '3.5']
8-
tf_versions: ['1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
8+
tf_versions: ['1.13.1', '1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
99
job:
1010
steps:
1111
- template: 'pretrained_model_test.yml'
@@ -14,7 +14,7 @@ jobs:
1414
parameters:
1515
platforms: ['linux', 'windows', 'mac']
1616
python_versions: ['3.7', '3.6', '3.5']
17-
tf_versions: ['1.13.1']
17+
tf_versions: ['1.14']
1818
job:
1919
steps:
2020
- template: 'pretrained_model_test.yml'

ci_build/azure_pipelines/pretrained_model_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ jobs:
44
- template: 'templates/job_generator.yml'
55
parameters:
66
python_versions: ['3.7', '3.6', '3.5']
7-
tf_versions: ['1.13.1']
7+
tf_versions: ['1.14.0']
88
job:
99
steps:
1010
- template: 'pretrained_model_test.yml'
1111

1212
- template: 'templates/job_generator.yml'
1313
parameters:
1414
platforms: ['windows', 'mac']
15-
tf_versions: ['1.13.1']
15+
tf_versions: ['1.14.0']
1616
job:
1717
steps:
1818
- template: 'pretrained_model_test.yml'

ci_build/azure_pipelines/templates/unit_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ steps:
99
export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
1010
export TF2ONNX_TEST_OPSET=$CI_ONNX_OPSET
1111
python -m pytest --cov=tf2onnx --cov-report=term --disable-pytest-warnings -r s tests --cov-append
12-
timeoutInMinutes: 5
12+
timeoutInMinutes: 15
1313
displayName: ${{ format('Run UnitTest - Opset{0}', onnx_opset) }}
1414
condition: succeededOrFailed()
1515
env:

ci_build/azure_pipelines/unit_test-matrix.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ stages:
77
parameters:
88
platforms: ['linux', 'windows', 'mac']
99
python_versions: ['3.6', '3.5']
10-
tf_versions: ['1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
10+
tf_versions: ['1.13.1','1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
1111
onnx_opsets: ['']
1212
job:
1313
steps:
@@ -17,8 +17,8 @@ stages:
1717
- template: 'templates/job_generator.yml'
1818
parameters:
1919
platforms: ['linux', 'windows', 'mac']
20-
python_versions: ['3.7']
21-
tf_versions: ['1.13.1']
20+
python_versions: ['3.7', '3.6', '3.5']
21+
tf_versions: ['1.14']
2222
onnx_opsets: ['']
2323
job:
2424
steps:

ci_build/azure_pipelines/unit_test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ stages:
66
- template: 'templates/job_generator.yml'
77
parameters:
88
python_versions: ['3.7', '3.6', '3.5']
9-
tf_versions: ['1.13.1']
9+
tf_versions: ['1.14']
1010
onnx_opsets: ['']
1111
job:
1212
steps:
@@ -15,7 +15,7 @@ stages:
1515

1616
- template: 'templates/job_generator.yml'
1717
parameters:
18-
tf_versions: ['1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
18+
tf_versions: ['1.13.1', '1.12', '1.11', '1.10', '1.9', '1.8', '1.7', '1.6', '1.5']
1919
onnx_opsets: ['']
2020
job:
2121
steps:
@@ -25,7 +25,7 @@ stages:
2525
- template: 'templates/job_generator.yml'
2626
parameters:
2727
platforms: ['windows', 'mac']
28-
tf_versions: ['1.13.1']
28+
tf_versions: ['1.14']
2929
onnx_opsets: ['']
3030
job:
3131
steps:

tests/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ def validate_const_node(node, expected_val):
297297
def group_nodes_by_type(graph):
298298
res = defaultdict(list)
299299
for node in graph.get_nodes():
300+
attr_body_graphs = node.get_body_graphs()
301+
if attr_body_graphs:
302+
for _, body_graph in attr_body_graphs.items():
303+
body_graph_res = group_nodes_by_type(body_graph)
304+
for k, v in body_graph_res.items():
305+
res[k].extend(v)
300306
res[node.type].append(node)
301307
return res
302308

tests/test_backend.py

Lines changed: 119 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,24 +108,17 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
108108
kwargs["constant_fold"] = False
109109
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
110110

111-
def _test_expand_dims(self, idx):
111+
def _test_expand_dims_known_rank(self, idx):
112112
tf.reset_default_graph()
113113
x_val = make_xval([3, 4])
114114
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
115115
op = tf.expand_dims(x, idx)
116116
_ = tf.identity(op, name=_TFOUTPUT)
117117
self._run_test_case([_OUTPUT], {_INPUT: x_val})
118118

119-
def test_expand_dims(self):
119+
def test_expand_dims_known_rank(self):
120120
for i in [-1, 0, 1, -2]:
121-
self._test_expand_dims(i)
122-
123-
def test_expand_dims_dynamic_inputs(self):
124-
x_val = make_xval([3, 4])
125-
x = tf.placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
126-
op = tf.expand_dims(x, 0)
127-
_ = tf.identity(op, name=_TFOUTPUT)
128-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
121+
self._test_expand_dims_known_rank(i)
129122

130123
def test_expand_dims_one_unknown_rank(self):
131124
tf.reset_default_graph()
@@ -135,14 +128,18 @@ def test_expand_dims_one_unknown_rank(self):
135128
_ = tf.identity(op, name=_TFOUTPUT)
136129
self._run_test_case([_OUTPUT], {_INPUT: x_val})
137130

138-
def test_expand_dims_more_unknown_rank(self):
131+
def _test_expand_dims_more_unknown_rank(self, idx):
139132
tf.reset_default_graph()
140133
x_val = make_xval([3, 4])
141134
x = tf.placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
142-
op = tf.expand_dims(x, 0)
135+
op = tf.expand_dims(x, idx)
143136
_ = tf.identity(op, name=_TFOUTPUT)
144137
self._run_test_case([_OUTPUT], {_INPUT: x_val})
145138

139+
def test_expand_dims_more_unknown_rank(self):
140+
for i in [-1, 0, 1, -2]:
141+
self._test_expand_dims_more_unknown_rank(i)
142+
146143
@check_opset_min_version(9, "ConstantOfShape")
147144
def test_eye_non_const1(self):
148145
# tf.eye(num_rows), num_rows is not const here
@@ -444,7 +441,9 @@ def test_dropout(self):
444441
feed_dict = {"input_1:0": x_val}
445442
input_names_with_port = ["input_1:0"]
446443
output_names_with_port = ["output:0"]
447-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
444+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port,
445+
graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
446+
check_op_count(g, "RandomUniformLike", 0)))
448447

449448
def test_nn_dropout(self):
450449
keep_prob = tf.placeholder_with_default(1., (), "keep_prob")
@@ -461,7 +460,10 @@ def test_nn_dropout(self):
461460
output_names_with_port = ["output:0"]
462461
# when constant_fold is enabled, PlaceholderWithDefault will be folded into either a const or a placeholder.
463462
# here we set it False to test PlaceholderWithDefault bug: https://github.com/onnx/tensorflow-onnx/pull/446
464-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False)
463+
# Dropout with ratio 1.0 will be optimized so that only one Identity is left
464+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False,
465+
graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
466+
check_op_count(g, "RandomUniformLike", 0)))
465467

466468
@check_tf_min_version("1.13")
467469
def test_nn_dropout_with_rate(self):
@@ -477,7 +479,9 @@ def test_nn_dropout_with_rate(self):
477479
feed_dict = {"input_1:0": x_val}
478480
input_names_with_port = ["input_1:0"]
479481
output_names_with_port = ["output:0"]
480-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False)
482+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False,
483+
graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
484+
check_op_count(g, "RandomUniformLike", 0)))
481485

482486
def test_conv2d_with_input_transpose(self):
483487
x_shape = [2, 32, 32, 3]
@@ -925,6 +929,22 @@ def test_leaky_relu_int(self):
925929
self._run_test_case([_OUTPUT], {_INPUT: x_val})
926930
tf.reset_default_graph()
927931

932+
@skip_caffe2_backend("fails on caffe2 with dim issue")
933+
@check_onnxruntime_incompatibility("Mul")
934+
def test_leaky_relu_with_dependency(self):
935+
x_val = 1000 * np.random.random_sample([1000, 100]).astype(np.float32)
936+
x = tf.placeholder(x_val.dtype, [None] * x_val.ndim, name=_TFINPUT)
937+
# simulate leaky_relu
938+
alpha = tf.constant(0.5)
939+
y = alpha * x
940+
x_ = tf.maximum(y, x)
941+
dependency = y - 1
942+
943+
_ = tf.identity(x_, name=_TFOUTPUT)
944+
_ = tf.identity(dependency, name=_TFOUTPUT1)
945+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
946+
tf.reset_default_graph()
947+
928948
@skip_caffe2_backend("fails on caffe2 with dim issue")
929949
@check_onnxruntime_incompatibility("Mul")
930950
def test_leaky_relu_float(self):
@@ -1057,6 +1077,15 @@ def test_slice(self):
10571077
_ = tf.identity(x_, name=_TFOUTPUT)
10581078
self._run_test_case([_OUTPUT], {_INPUT: x_val})
10591079

1080+
def test_slice_neg_size(self):
1081+
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
1082+
t1 = tf.constant([0, 1], dtype=tf.int32)
1083+
t2 = tf.constant([-1, 2], dtype=tf.int32)
1084+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1085+
x_ = tf.slice(x0, t1, t2)
1086+
_ = tf.identity(x_, name=_TFOUTPUT)
1087+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1088+
10601089
@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
10611090
def test_slice_with_non_const(self):
10621091
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
@@ -2237,7 +2266,6 @@ def test_sparse_softmax_cross_entropy_with_logits_large_class(self):
22372266

22382267
self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val}, rtol=1e-6)
22392268

2240-
@skip_onnxruntime_backend("onnxruntime Slice did not supported BOOL")
22412269
def test_matrix_band_part(self):
22422270
input_val = np.random.randint(0, 666, (10, 15)).astype(np.int32)
22432271
input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name=_TFINPUT)
@@ -2247,7 +2275,6 @@ def test_matrix_band_part(self):
22472275
_ = tf.identity(res1, name=_TFOUTPUT1)
22482276
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: input_val})
22492277

2250-
@skip_onnxruntime_backend("onnxruntime Slice did not supported BOOL.")
22512278
def test_matrix_band_part_2(self):
22522279
input_val = np.random.randint(0, 666, (1, 1)).astype(np.int32)
22532280
input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name=_TFINPUT)
@@ -2416,7 +2443,7 @@ def test_softsign(self):
24162443

24172444
def test_batch_to_spacend(self):
24182445
block_size = [2, 2]
2419-
crop = [[0, 1], [2, 1]]
2446+
crop = [[1, 0], [2, 1]]
24202447

24212448
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)
24222449
input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC
@@ -2551,6 +2578,80 @@ def test_selu(self):
25512578
_ = tf.identity(y, name=_TFOUTPUT)
25522579
self._run_test_case([_OUTPUT], {_INPUT: x_val})
25532580

2581+
# test for gemm pattern0: alpha*A*B + beta*C
2582+
def test_gemm_pattern0(self):
2583+
max_number = 10
2584+
m = np.random.randint(max_number)
2585+
n = np.random.randint(max_number)
2586+
k = np.random.randint(max_number)
2587+
x_val1 = np.random.rand(m, n).astype("float32")
2588+
x_val2 = np.random.rand(n, k).astype("float32")
2589+
x_val3 = np.random.rand(m, k).astype("float32")
2590+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2591+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2592+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2593+
alpha = tf.constant(1.0, dtype=tf.float32)
2594+
beta = tf.constant(2.0, dtype=tf.float32)
2595+
mul1 = tf.multiply(alpha, tf.matmul(a, b))
2596+
mul2 = tf.multiply(beta, c)
2597+
x_ = mul1 + mul2
2598+
_ = tf.identity(x_, name=_TFOUTPUT)
2599+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2600+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2601+
2602+
# test for gemm pattern1: alpha*A*B + C
2603+
def test_gemm_pattern1(self):
2604+
max_number = 10
2605+
m = np.random.randint(max_number)
2606+
n = np.random.randint(max_number)
2607+
k = np.random.randint(max_number)
2608+
x_val1 = np.random.rand(m, n).astype("float32")
2609+
x_val2 = np.random.rand(n, k).astype("float32")
2610+
x_val3 = np.random.rand(m, k).astype("float32")
2611+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2612+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2613+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2614+
alpha = tf.constant(1.0, dtype=tf.float32)
2615+
x_ = tf.multiply(alpha, tf.matmul(a, b)) + c
2616+
_ = tf.identity(x_, name=_TFOUTPUT)
2617+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2618+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2619+
2620+
# test for gemm pattern2: A*B + beta*C
2621+
def test_gemm_pattern2(self):
2622+
max_number = 10
2623+
m = np.random.randint(max_number)
2624+
n = np.random.randint(max_number)
2625+
k = np.random.randint(max_number)
2626+
x_val1 = np.random.rand(m, n).astype("float32")
2627+
x_val2 = np.random.rand(n, k).astype("float32")
2628+
x_val3 = np.random.rand(m, k).astype("float32")
2629+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2630+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2631+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2632+
beta = tf.constant(2.0, dtype=tf.float32)
2633+
x_ = tf.matmul(a, b) + tf.multiply(beta, c)
2634+
_ = tf.identity(x_, name=_TFOUTPUT)
2635+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2636+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2637+
2638+
# test for gemm pattern3: A*B + C
2639+
def test_gemm_pattern3(self):
2640+
max_number = 10
2641+
m = np.random.randint(max_number)
2642+
n = np.random.randint(max_number)
2643+
k = np.random.randint(max_number)
2644+
x_val1 = np.random.rand(m, n).astype("float32")
2645+
x_val2 = np.random.rand(n, k).astype("float32")
2646+
x_val3 = np.random.rand(m, k).astype("float32")
2647+
a = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
2648+
b = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
2649+
c = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT2)
2650+
x_ = tf.matmul(a, b) + c
2651+
_ = tf.identity(x_, name=_TFOUTPUT)
2652+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
2653+
graph_validator=lambda g: check_op_count(g, "Gemm", 1))
2654+
25542655
def test_graph_matcher(self):
25552656
shape = [2, 6]
25562657
x_val = np.random.random(shape).astype(np.float32)

0 commit comments

Comments
 (0)