7
7
8
8
from __future__ import unicode_literals
9
9
10
+ import numpy as np
10
11
from tf2onnx .utils import ONNX_DTYPE_NAMES # lgtm[py/unsafe-cyclic-import]
11
12
from .optimizer_base import GraphOptimizerBase # lgtm[py/unsafe-cyclic-import]
12
13
@@ -152,7 +153,6 @@ def _optimize_transpose(g, node, consumer_nodes):
152
153
@_register_func (('Squeeze' , 'Unsqueeze' ))
153
154
def _optimize_squeeze_unsqueeze (g , node , consumer_nodes ):
154
155
"""remove pairs of squeeze-unsqueeze nodes"""
155
-
156
156
if node .type != 'Squeeze' or len (consumer_nodes ) != 1 :
157
157
# no need to return any value, since not removing long chain of nodes
158
158
return []
@@ -177,3 +177,67 @@ def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
177
177
g .remove_node (node .name )
178
178
g .remove_node (node2 .name )
179
179
return []
180
+
181
@staticmethod
@_register_func(('Conv', 'BatchNormalization'))
def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
    """Fuse a Conv node with a directly-following BatchNormalization node.

    BatchNormalization computes ``y = scale * (x - mean) / sqrt(var + eps) + offset``
    per output channel; when all BN inputs and the Conv weights/bias are
    constants, this can be folded into new Conv weights and bias, and the
    BN node removed.

    Args:
        g: the Graph being optimized.
        node: candidate node (must be the Conv of a Conv->BN pair).
        consumer_nodes: nodes consuming ``node``'s output.

    Returns:
        [] always — this optimizer never extends the node chain.
    """
    if node.type != 'Conv' or len(consumer_nodes) != 1:
        # can only fuse when the Conv's single consumer is the BN node
        return []

    node2 = consumer_nodes[0]
    if node2.type != 'BatchNormalization':
        return []

    # if batchnorm output is a graph output, it cannot be deleted - skip
    if set(node2.output) & set(g.outputs):
        return []

    # conv weights must be const so they can be folded
    if not node.inputs[1].is_const():
        return []
    weights = node.inputs[1].get_tensor_value(as_list=False)
    # only 4D weights (OIHW, i.e. NCHW conv) are handled - skip otherwise
    if len(weights.shape) != 4:
        return []

    bias = 0
    # optional conv bias value must also be const
    if len(node.inputs) > 2:
        if not node.inputs[2].is_const():
            return []
        bias = node.inputs[2].get_tensor_value(as_list=False)

    # scale, offset, mean, var must all be const, otherwise skip
    if not all(node2.inputs[i].is_const() for i in [1, 2, 3, 4]):
        return []

    # if bn's extra outputs (e.g. saved mean/var) are used elsewhere, cannot fuse
    for i in range(1, len(node2.output)):
        if g.find_output_consumers(node2.output[i]):
            return []

    # move the output-channel axis last (OIHW -> HWIO) so the per-channel
    # scale broadcasts across it
    weights = weights.transpose(2, 3, 1, 0)
    scale = node2.inputs[1].get_tensor_value(as_list=False)
    offset = node2.inputs[2].get_tensor_value(as_list=False)
    mean = node2.inputs[3].get_tensor_value(as_list=False)
    var = node2.inputs[4].get_tensor_value(as_list=False)
    # epsilon is an optional attribute; ONNX spec default is 1e-5.
    # get_attr returns None when absent, so guard before reading .f
    epsilon_attr = node2.get_attr('epsilon')
    epsilon = epsilon_attr.f if epsilon_attr is not None else 1e-5

    scale_new = scale / np.sqrt(var + epsilon)
    weights_new = weights * scale_new
    weights_new = weights_new.transpose(3, 2, 0, 1)  # back to OIHW
    bias_new = (bias - mean) * scale_new + offset
    bias_new_const = g.make_const(node.name + '_bias_fused_bn', bias_new)
    weights_new_const = g.make_const(node.name + '_weights_fused_bn', weights_new)
    node.input = [node.input[0], weights_new_const.output[0], bias_new_const.output[0]]

    # let the conv take over bn's (first) output, then delete bn
    node2_output = node2.output[:1]
    node2_shape = g.get_shape(node2.output[0])
    node2_dtype = g.get_dtype(node2.output[0])
    g.remove_node(node2.name)
    node.output = node2_output
    g.set_shape(node2_output[0], node2_shape)
    g.set_dtype(node2_output[0], node2_dtype)
    return []
0 commit comments