Skip to content

Commit 353e46f

Browse files
authored
Merge pull request #1106 from onnx/gs/fix-f16-fuse
fix bn fuse for fp16
2 parents 010fad1 + 472770c commit 353e46f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,13 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
201201
if len(weights.shape) != 4:
202202
return []
203203

204-
bias = 0
205204
# optional bias value
206205
if len(node.inputs) > 2:
207206
if not node.inputs[2].is_const():
208207
return []
209208
bias = node.inputs[2].get_tensor_value(as_list=False)
209+
else:
210+
bias = np.array(0, dtype=weights.dtype)
210211

211212
# scale, offset, mean, var be const, otherwise skip
212213
if False in [node2.inputs[i].is_const() for i in [1, 2, 3, 4]]:
@@ -228,8 +229,8 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
228229
weights_new = weights * scale_new
229230
weights_new = weights_new.transpose(3, 2, 0, 1)
230231
bias_new = (bias - mean) * scale_new + offset
231-
bias_new_const = g.make_const(node.name + '_bias_fused_bn', bias_new)
232-
weights_new_const = g.make_const(node.name + '_weights_fused_bn', weights_new)
232+
bias_new_const = g.make_const(node.name + '_bias_fused_bn', bias_new.astype(bias.dtype))
233+
weights_new_const = g.make_const(node.name + '_weights_fused_bn', weights_new.astype(weights.dtype))
233234
g.replace_inputs(node, [node.input[0], weights_new_const.output[0], bias_new_const.output[0]])
234235

235236
# fuse conv and bn, delete bn

0 commit comments

Comments
 (0)