Skip to content

Commit b27aa05

Browse files
authored
Fix a bug where multiple (conv, batch_norm) ops could not be optimized. (#2187)
Signed-off-by: Jay Zhang <[email protected]>
1 parent 554d90a commit b27aa05

File tree

7 files changed

+57
-36
lines changed

7 files changed

+57
-36
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ The common issues we run into we try to document here [Troubleshooting Guide](Tr
1717

1818
| Build Type | OS | Python | TensorFlow | ONNX opset | Status |
1919
| --- | --- | --- | --- | --- | --- |
20-
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.7-3.10 | 1.13-1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=main) |
21-
| Unit Test - Full | Linux, MacOS, Windows | 3.7-3.10 | 1.13-1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=main) | |
20+
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.7-3.10 | 1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=main) |
21+
| Unit Test - Full | Linux, MacOS, Windows | 3.7-3.10 | 1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=main) | |
2222
<br/>
2323

2424
## Supported Versions

ci_build/azure_pipelines/onnxruntime_nightly_test.yml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,6 @@ stages:
1616
- template: 'unit_test.yml'
1717
report_coverage: 'True'
1818

19-
- template: 'templates/job_generator.yml'
20-
parameters:
21-
platforms: ['linux', 'windows']
22-
python_versions: ['3.7']
23-
tf_versions: ['1.14.0']
24-
onnx_opsets: ['']
25-
onnx_backends: {onnxruntime: ['nightly']}
26-
job:
27-
steps:
28-
- template: 'unit_test.yml'
29-
report_coverage: 'True'
30-
3119
- template: 'templates/job_generator.yml'
3220
parameters:
3321
platforms: ['linux', 'windows']

ci_build/azure_pipelines/pretrained_model_test-matrix.yml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
# Pre-trained model test, full matrix
22

33
jobs:
4-
- template: 'templates/job_generator.yml'
5-
parameters:
6-
platforms: ['linux', 'windows']
7-
python_versions: ['3.7']
8-
tf_versions: ['1.14.0']
9-
job:
10-
steps:
11-
- template: 'pretrained_model_test.yml'
12-
134
- template: 'templates/job_generator.yml'
145
parameters:
156
platforms: ['linux', 'windows']

ci_build/azure_pipelines/unit_test-matrix.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ stages:
77
parameters:
88
platforms: ['linux', 'windows']
99
python_versions: ['3.7']
10-
tf_versions: ['1.14.0', '1.15.2']
10+
tf_versions: ['1.15.2']
1111
onnx_opsets: ['']
1212
job:
1313
steps:

ci_build/azure_pipelines/unit_test.yml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,6 @@ stages:
122122
- template: 'unit_test.yml'
123123
report_coverage: 'True'
124124

125-
- template: 'templates/job_generator.yml'
126-
parameters:
127-
platforms: ['windows']
128-
tf_versions: ['1.14.0']
129-
onnx_opsets: ['14']
130-
job:
131-
steps:
132-
- template: 'unit_test.yml'
133-
report_coverage: 'True'
134-
135125
- template: 'templates/job_generator.yml'
136126
parameters:
137127
python_versions: ['3.8']

tests/test_backend.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3087,6 +3087,57 @@ def graph_validator(g):
30873087

30883088
self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)
30893089

3090+
@check_opset_min_version(7, "batchnorm")
3091+
def test_multiple_conv2d_fused_batchnorm(self):
3092+
x_shape = [1, 28, 28, 2]
3093+
x_val = np.random.random_sample(x_shape).astype(np.float32)
3094+
w = np.array([[2., 1., 1.],
3095+
[1., 3., 1.],
3096+
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
3097+
# 2 channels for input and output
3098+
w = np.concatenate([w, w, w, w]).reshape([3, 3, 2, 2])
3099+
scale_dtype = np.float32
3100+
scale_shape = x_shape[-1:]
3101+
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
3102+
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
3103+
mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
3104+
var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
3105+
3106+
def func_conv2d(x):
3107+
kernel = tf.constant(w, dtype=tf.float32, name='k')
3108+
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
3109+
return conv
3110+
3111+
def func_multiple_fusedbn(x):
3112+
scale = tf.constant(scale_val, name='scale')
3113+
offset = tf.constant(offset_val, name='offset')
3114+
mean = tf.constant(mean_val, name='mean')
3115+
var = tf.constant(var_val, name='variance')
3116+
epsilon = 0.1234
3117+
y, _, _ = fused_batch_norm(
3118+
func_conv2d(x), scale, offset, mean=mean, variance=var,
3119+
epsilon=epsilon, data_format='NHWC', is_training=False)
3120+
3121+
y = tf.nn.relu(y)
3122+
3123+
y, _, _ = fused_batch_norm(
3124+
func_conv2d(y), scale, offset, mean=mean, variance=var,
3125+
epsilon=epsilon, data_format='NHWC', is_training=False)
3126+
3127+
y, _, _ = fused_batch_norm(
3128+
func_conv2d(y), scale, offset, mean=mean, variance=var,
3129+
epsilon=epsilon, data_format='NHWC', is_training=False)
3130+
3131+
return tf.identity(y, name=_TFOUTPUT)
3132+
3133+
def graph_validator(g):
3134+
if 'BatchNormalization' in [n.type for n in g.get_nodes()]:
3135+
return False
3136+
return True
3137+
3138+
self._run_test_case(func_multiple_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
3139+
graph_validator=graph_validator)
3140+
30903141
@check_tf_min_version("1.15")
30913142
@check_opset_min_version(10, "quantize_and_dequantize")
30923143
def test_qdq_unsigned_input(self):

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def _optimize_at_current_graph_level(self, g):
4141

4242
# topological sort of candidates
4343
# simplifying assumption for back-to-back-optimizer is
44-
# the op_types have 1 input, 1 output, but multiple consumers
44+
# the op_types have 1 input, 1 output, but multiple consumers.
45+
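# if optype contains 2 elements, only nodes of the first type get an entry here (i.e. act as producers); nodes of the second type can still be recorded as consumers, but get no consumer list of their own.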
4546
has_dependencies = set()
46-
consumer_node_ids = {n.output[0]: [] for n in nodes}
47+
consumer_node_ids = {n.output[0]: [] for n in nodes if len(optype) < 2 or n.type == optype[0]}
4748
for n in nodes:
4849
if n.input[0] in consumer_node_ids:
4950
consumer_node_ids[n.input[0]].extend([n])

0 commit comments

Comments
 (0)