@@ -201,12 +201,13 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
201
201
if len (weights .shape ) != 4 :
202
202
return []
203
203
204
- bias = 0
205
204
# optional bias value
206
205
if len (node .inputs ) > 2 :
207
206
if not node .inputs [2 ].is_const ():
208
207
return []
209
208
bias = node .inputs [2 ].get_tensor_value (as_list = False )
209
+ else :
210
+ bias = np .array (0 , dtype = weights .dtype )
210
211
211
212
# scale, offset, mean, var be const, otherwise skip
212
213
if False in [node2 .inputs [i ].is_const () for i in [1 , 2 , 3 , 4 ]]:
@@ -228,8 +229,8 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
228
229
weights_new = weights * scale_new
229
230
weights_new = weights_new .transpose (3 , 2 , 0 , 1 )
230
231
bias_new = (bias - mean ) * scale_new + offset
231
- bias_new_const = g .make_const (node .name + '_bias_fused_bn' , bias_new )
232
- weights_new_const = g .make_const (node .name + '_weights_fused_bn' , weights_new )
232
+ bias_new_const = g .make_const (node .name + '_bias_fused_bn' , bias_new . astype ( bias . dtype ) )
233
+ weights_new_const = g .make_const (node .name + '_weights_fused_bn' , weights_new . astype ( weights . dtype ) )
233
234
g .replace_inputs (node , [node .input [0 ], weights_new_const .output [0 ], bias_new_const .output [0 ]])
234
235
235
236
# fuse conv and bn, delete bn
0 commit comments