Skip to content

Commit 59cfa79

Browse files
hwangdeyuguschmuefatcat-z
authored
Fix keras Conv2D BiasAdd fuse (#1796)
* fix Conv2D Bias Add fuse Signed-off-by: hwangdeyu <[email protected]> * add the tf1 keras missing optimization Signed-off-by: hwangdeyu <[email protected]> * fix pylint Signed-off-by: hwangdeyu <[email protected]> * add skip_tf_cpu decorator Signed-off-by: hwangdeyu <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]> Co-authored-by: fatcat-z <[email protected]> * change to tf.Graph Signed-off-by: hwangdeyu <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]> Co-authored-by: fatcat-z <[email protected]> * fix grammmer comment Signed-off-by: hwangdeyu <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]> Co-authored-by: fatcat-z <[email protected]>
1 parent 5cd3b5b commit 59cfa79

File tree

7 files changed

+69
-31
lines changed

7 files changed

+69
-31
lines changed

tests/keras2onnx_applications/nightly_build/test_nlp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import unittest
66
import mock_keras2onnx
77
import numpy as np
8-
from mock_keras2onnx.proto import keras, is_tf_keras
8+
from mock_keras2onnx.proto import keras, is_tensorflow_older_than
99
from os.path import dirname, abspath
1010
sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
1111
from test_utils import run_onnx_runtime
@@ -91,6 +91,7 @@ def test_babi_rnn(self):
9191
expected = model.predict([x, y])
9292
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, {model.input_names[0]: x, model.input_names[1]: y}, expected, self.model_files))
9393

94+
@unittest.skipIf(is_tensorflow_older_than('2.0.0'), "Result is slightly different in tf1")
9495
@unittest.skipIf(get_maximum_opset_supported() < 9,
9596
"None seq_length LSTM is not supported before opset 9.")
9697
def test_imdb_bidirectional_lstm(self):

tests/test_backend.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,29 @@ def func(x):
740740
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
741741
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
742742

743+
@check_tf_min_version("1.15")
744+
@skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
745+
def test_conv2d_biasadd_rewriter(self):
746+
x_shape = [2, 3, 32, 16]
747+
x_val = make_xval(x_shape)
748+
def func(x):
749+
middles = tf.keras.layers.ZeroPadding2D(
750+
padding=(0, 4),
751+
data_format="channels_first",
752+
name="padding"
753+
)(x)
754+
t = tf.keras.layers.Conv2D(
755+
filters=768,
756+
kernel_size=3,
757+
strides=1,
758+
use_bias=True,
759+
data_format="channels_first",
760+
name="conv2d"
761+
)(middles)
762+
return tf.identity(t, name=_TFOUTPUT)
763+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
764+
graph_validator=lambda g: check_op_count(g, "Add", 0, disabled=False))
765+
743766
@check_tf_min_version("1.15")
744767
def test_conv2d_dilations_rewriter(self):
745768
x_shape = [2, 32, 16, 3]
@@ -2353,6 +2376,9 @@ def func(x):
23532376
return tf.identity(x_, name=_TFOUTPUT)
23542377
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
23552378

2379+
@skip_tflite("tflite does not support uint32 if tf version <= 2.3.0")
2380+
@check_opset_min_version(6, "cast")
2381+
def test_cast_unit32(self):
23562382
x_val = np.array([1, 2, 3, 4], dtype=np.uint32).reshape((2, 2))
23572383
def func(x):
23582384
x_ = tf.cast(x, tf.uint64)

tf2onnx/convert.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,9 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu
373373

374374
with tf.device("/cpu:0"):
375375
frozen_graph, initialized_tables = tf_loader.freeze_session(sess, input_names, output_names, get_tables=True)
376+
with tf.Graph().as_default():
377+
tf.import_graph_def(frozen_graph, name="")
378+
frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph, False)
376379
model_proto, external_tensor_storage = _convert_common(
377380
frozen_graph,
378381
name=model.name,

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
__all__ = [
3030
"rewrite_cond",
31-
"rewrite_conv2d_with_pad",
3231
"rewrite_dropout",
3332
"rewrite_eye",
3433
"rewrite_flatten",
@@ -49,6 +48,7 @@
4948
"rewrite_quantize_and_dequantize",
5049
"rewrite_layer_normalization",
5150
"rewrite_conv_dilations",
51+
"rewrite_conv2d_with_pad",
5252
"rewrite_ragged_variant_shape",
5353
"rewriter_lstm_tf2",
5454
"rewrite_gru_tf2",

tf2onnx/rewriter/conv2d_with_add_rewriter.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,40 @@
1313
# pylint: disable=missing-docstring
1414

1515
def rewrite_biasadd_with_conv2d(g, ops):
16-
pattern = \
16+
pattern1 = \
1717
OpTypePattern('BiasAdd', name='biasadd', inputs=[
1818
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
19-
matcher = GraphMatcher(pattern)
20-
match_results = list(matcher.match_ops(ops))
21-
for match in match_results:
22-
biasadd = match.get_op('biasadd')
23-
conv = match.get_op('conv')
24-
25-
#backup the conv and biasadd values
26-
conv_type = conv.type
27-
conv_input = conv.input
28-
conv_attr = conv.attr
29-
dtype = g.get_dtype(conv.output[0])
30-
shape = g.get_shape(conv.output[0])
31-
conv_name = biasadd.name
32-
conv_output = biasadd.output
33-
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
34-
35-
if len(g.find_output_consumers(conv.output[0])) > 1:
36-
continue
37-
# Remove the Conv and BiasAdd node
38-
g.remove_node(conv.name)
39-
g.remove_node(biasadd.name)
40-
41-
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
42-
shapes=[shape], dtypes=[dtype], skip_conversion=False)
19+
pattern2 = \
20+
OpTypePattern('BiasAdd', name='biasadd', inputs=[
21+
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=[
22+
'*', '*', '*']), '*'], allow_reorder=True)
23+
24+
for pattern in [pattern1, pattern2]:
25+
matcher = GraphMatcher(pattern)
26+
match_results = list(matcher.match_ops(ops))
27+
for match in match_results:
28+
biasadd = match.get_op('biasadd')
29+
conv = match.get_op('conv')
30+
31+
# Backup the conv and biasadd values
32+
conv_type = conv.type
33+
conv_input = conv.input
34+
conv_attr = conv.attr
35+
dtype = g.get_dtype(conv.output[0])
36+
shape = g.get_shape(conv.output[0])
37+
conv_name = biasadd.name
38+
conv_output = biasadd.output
39+
if pattern == pattern2:
40+
conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
41+
else:
42+
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
43+
44+
if len(g.find_output_consumers(conv.output[0])) > 1:
45+
continue
46+
# Remove the Conv and BiasAdd node
47+
g.remove_node(conv.name)
48+
g.remove_node(biasadd.name)
49+
50+
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
51+
shapes=[shape], dtypes=[dtype], skip_conversion=False)
4352
return ops

tf2onnx/tf_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
681681
rewrite_options = config.graph_options.rewrite_options
682682
config.graph_options.infer_shapes = True
683683
# TODO: if we turn on pruning, grappler removes some identities that the tf-1.x lstm rewriter
684-
# depends on so for now don't turn this on.
684+
# depends on so for now don't turn this on, fold_constant is always enabled now.
685685
rewrite_options.optimizers[:] = [
686686
# 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
687687
'constfold', 'function'

tf2onnx/version.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
4-
version = '1.8.0'
5-
git_version = '24080398ff4793ed8aac028ffa4b714a4803d7fb'
3+
version = '1.10.0'
4+
git_version = '219e00c073f6e73fba7335630dcf1f96cc82c983'

0 commit comments

Comments
 (0)