Skip to content

Commit bbaa8a3

Browse files
committed
Create a rewriter for fusing Conv2d with BiasAdd
1 parent e467d58 commit bbaa8a3

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
rewrite_custom_rnn_cell, rewrite_generic_loop
2121
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
2222
from tf2onnx.rewriter.transpose_rewriter import rewrite_transpose
23+
from tf2onnx.rewriter.conv2d_with_add_rewriter import rewrite_biasadd_with_conv2d
2324

2425

2526
__all__ = [
@@ -41,4 +42,5 @@
4142
"rewrite_bi_direction_gru",
4243
"rewrite_custom_rnn_cell",
4344
"rewrite_generic_loop",
45+
"rewrite_biasadd_with_conv2d",
4446
]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx conv2d op with BiasAdd
6+
"""
7+
from tf2onnx import logging
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
# pylint: disable=missing-docstring
14+
15+
def rewrite_biasadd_with_conv2d(g, ops):
16+
pattern = \
17+
OpTypePattern('BiasAdd', name='biasadd', inputs=[
18+
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
19+
matcher = GraphMatcher(pattern)
20+
match_results = list(matcher.match_ops(ops))
21+
for match in match_results:
22+
biasadd = match.get_op('biasadd')
23+
conv = match.get_op('conv')
24+
25+
#backup the conv and biasadd values
26+
conv_type = conv.type
27+
conv_input = conv.input
28+
conv_attr = conv.attr
29+
dtype = g.get_dtype(conv.output[0])
30+
shape = g.get_shape(conv.output[0])
31+
conv_name = biasadd.name
32+
conv_output = biasadd.output
33+
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
34+
35+
# Remove the Conv and BiasAdd node
36+
g.remove_node(conv.name)
37+
g.remove_node(biasadd.name)
38+
39+
g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
40+
shapes=[shape], dtypes=[dtype], skip_conversion=False)
41+
return ops

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,7 @@ def compat_handler(ctx, node, **kwargs):
467467
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
468468
rewrite_single_direction_gru, rewrite_bi_direction_gru,
469469
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
470+
rewrite_biasadd_with_conv2d,
470471
]
471472

472473
if custom_rewriter is not None:

tf2onnx/version.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
2-
version = '1.6.0'
3-
git_version = '82f805f8fe7d2fa91e6ca9d39e153712f6887fec'
1+
2+
version = '1.6.0'
3+
git_version = 'be91553e30582216dfc67e1808e7eb42bb566541'

0 commit comments

Comments
 (0)