Commit 3c7a848

Add fusion for conv + batchnorm
1 parent aeaaa10 commit 3c7a848

2 files changed: +104 -1 lines changed
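The commit folds an inference-mode BatchNormalization that directly follows a Conv into the Conv's weights and bias. As a minimal numpy sketch of the identity the optimizer relies on (illustrative names, not code from this commit; a 1x1 conv is modeled as a plain matmul):

    import numpy as np

    rng = np.random.default_rng(0)
    eps = 0.1234
    c_in, c_out = 3, 4
    x = rng.random((c_in, 8)).astype(np.float32)        # input, spatial dims flattened
    W = rng.random((c_out, c_in)).astype(np.float32)    # a 1x1 conv is just a matmul
    bias = rng.random(c_out).astype(np.float32)
    scale, offset, mean, var = (rng.random(c_out).astype(np.float32) for _ in range(4))

    # conv followed by inference-mode batchnorm
    conv = W @ x + bias[:, None]
    bn = scale[:, None] * (conv - mean[:, None]) / np.sqrt(var[:, None] + eps) + offset[:, None]

    # same result from a single conv with folded parameters:
    #   s = scale / sqrt(var + eps);  W' = W * s;  b' = (bias - mean) * s + offset
    s = scale / np.sqrt(var + eps)
    fused = (W * s[:, None]) @ x + ((bias - mean) * s + offset)[:, None]

    assert np.allclose(bn, fused, rtol=1e-5)

Because the batchnorm scale acts per output channel, scaling the conv's output is the same as scaling the corresponding rows of its weight tensor, which is why the fold is exact up to float rounding.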

tests/test_backend.py

Lines changed: 39 additions & 0 deletions
@@ -1870,6 +1870,45 @@ def func(x, mean, offset, var):
             return tf.identity(y, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: mean_val, _INPUT2: offset_val, _INPUT3: var_val})
 
+    @check_opset_min_version(7, "batchnorm")
+    def test_conv2d_batchnorm_fusion(self):
+        x_shape = [1, 28, 28, 2]
+        x_val = np.random.random_sample(x_shape).astype(np.float32)
+        w = np.array([[2., 1., 1.],
+                      [1., 3., 1.],
+                      [1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
+        # 2 channels for input and output
+        w = np.concatenate([w, w, w, w]).reshape([3, 3, 2, 2])
+        scale_dtype = np.float32
+        scale_shape = x_shape[-1:]
+        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+
+        def func_conv2d(x):
+            kernel = tf.constant(w, dtype=tf.float32, name='k')
+            conv = tf.nn.conv2d(x, kernel, padding='VALID')
+            return conv
+
+        def func_fusedbn(x):
+            scale = tf.constant(scale_val, name='scale')
+            offset = tf.constant(offset_val, name='offset')
+            mean = tf.constant(mean_val, name='mean')
+            var = tf.constant(var_val, name='variance')
+            epsilon = 0.1234
+            y, _, _ = fused_batch_norm(
+                func_conv2d(x), scale, offset, mean=mean, variance=var,
+                epsilon=epsilon, data_format='NHWC', is_training=False)
+            return tf.identity(y, name=_TFOUTPUT)
+
+        def graph_validator(g):
+            if 'BatchNormalization' in [n.type for n in g.get_nodes()]:
+                return False
+            return True
+
+        self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)
+
     @skip_caffe2_backend()
     @check_opset_min_version(7, "resize_nearest_neighbor")
     def test_resize_nearest_neighbor(self):

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 65 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 from __future__ import unicode_literals
 
+import numpy as np
 from tf2onnx.utils import ONNX_DTYPE_NAMES # lgtm[py/unsafe-cyclic-import]
 from .optimizer_base import GraphOptimizerBase # lgtm[py/unsafe-cyclic-import]
 
@@ -152,7 +153,6 @@ def _optimize_transpose(g, node, consumer_nodes):
     @_register_func(('Squeeze', 'Unsqueeze'))
     def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
         """remove pairs of squeeze-unsqueeze nodes"""
-
         if node.type != 'Squeeze' or len(consumer_nodes) != 1:
             # no need to return any value, since not removing long chain of nodes
             return []
@@ -177,3 +177,67 @@ def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
         g.remove_node(node.name)
         g.remove_node(node2.name)
         return []
+
+    @staticmethod
+    @_register_func(('Conv', 'BatchNormalization'))
+    def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
+        """fuse conv and batchnorm"""
+        if node.type != 'Conv' or len(consumer_nodes) != 1:
+            # can only fuse a conv with a single batchnorm consumer
+            return []
+
+        node2 = consumer_nodes[0]
+        if node2.type != 'BatchNormalization':
+            return []
+
+        # if batchnorm is a graph output, skip
+        if set(node2.output) & set(g.outputs):
+            return []
+
+        if not node.inputs[1].is_const():
+            return []
+        weights = node.inputs[1].get_tensor_value(as_list=False)
+        # skip if the weights are not 4D (NCHW conv2d)
+        if len(weights.shape) != 4:
+            return []
+
+        bias = 0
+        # optional bias value
+        if len(node.inputs) > 2:
+            if not node.inputs[2].is_const():
+                return []
+            bias = node.inputs[2].get_tensor_value(as_list=False)
+
+        # scale, offset, mean and var must be const, otherwise skip
+        if False in [node2.inputs[i].is_const() for i in [1, 2, 3, 4]]:
+            return []
+
+        # if the remaining bn outputs are used elsewhere, cannot fuse
+        for i in range(1, len(node2.output)):
+            if g.find_output_consumers(node2.output[i]):
+                return []
+
+        weights = weights.transpose(2, 3, 1, 0)
+        scale = node2.inputs[1].get_tensor_value(as_list=False)
+        offset = node2.inputs[2].get_tensor_value(as_list=False)
+        mean = node2.inputs[3].get_tensor_value(as_list=False)
+        var = node2.inputs[4].get_tensor_value(as_list=False)
+        epsilon = node2.get_attr('epsilon').f
+
+        scale_new = scale / np.sqrt(var + epsilon)
+        weights_new = weights * scale_new
+        weights_new = weights_new.transpose(3, 2, 0, 1)
+        bias_new = (bias - mean) * scale_new + offset
+        bias_new_const = g.make_const(node.name + '_bias_fused_bn', bias_new)
+        weights_new_const = g.make_const(node.name + '_weights_fused_bn', weights_new)
+        node.input = [node.input[0], weights_new_const.output[0], bias_new_const.output[0]]
+
+        # fuse conv and bn: delete bn and give conv its output
+        node2_output = node2.output[:1]
+        node2_shape = g.get_shape(node2.output[0])
+        node2_dtype = g.get_dtype(node2.output[0])
+        g.remove_node(node2.name)
+        node.output = node2_output
+        g.set_shape(node2_output[0], node2_shape)
+        g.set_dtype(node2_output[0], node2_dtype)
+        return []
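A note on the transpose pair around the weight update above: ONNX Conv weights are laid out [M, C/group, kH, kW] with M the output-channel count, so moving M to the last axis lets scale_new (shape [M]) broadcast with a plain multiply before the original layout is restored. A standalone sketch of just that broadcasting step (values are made up):

    import numpy as np

    M, C, kH, kW = 2, 3, 3, 3                      # M = output channels
    W = np.arange(M * C * kH * kW, dtype=np.float32).reshape(M, C, kH, kW)
    s = np.array([0.5, 2.0], dtype=np.float32)     # one scale per output channel

    # move M to the last axis so `* s` broadcasts per output channel, then restore
    W_fused = (W.transpose(2, 3, 1, 0) * s).transpose(3, 2, 0, 1)

    # equivalent single-step form with explicit broadcast axes
    assert np.array_equal(W_fused, W * s[:, None, None, None])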
