Skip to content

Commit 0c67734

Browse files
authored
Merge pull request #99 from pengwa/batchnorm-mean-var
fix batch norm mean/var dimention issue
2 parents 055419e + 2e39b0c commit 0c67734

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,8 +1027,25 @@ def fused_batchnorm_op7(ctx, node, name, args):
10271027
# tf outputs: y, batch_mean, batch_var
10281028
# a: data_format, epsilon, is_training
10291029
# onnx inputs: X, scale, B, mean, variance, attributes: epsilon, momentum=0.9, spatial : 1
1030-
# output: mean, var, savedmean, savedvar,
1030+
# output: y, mean, var, savedmean, savedvar,
10311031
nodes = conv_convert_inputs(ctx, node, with_kernel=False)
1032+
scale_shape = ctx.get_shape(node.input[1])
1033+
mean_shape = ctx.get_shape(node.input[3])
1034+
var_shape = ctx.get_shape(node.input[4])
1035+
val_type = utils.ONNX_TO_NUMPY_DTYPE[node.inputs[1].dtype]
1036+
1037+
if mean_shape != scale_shape:
1038+
new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(), scale_shape), dtype=val_type)
1039+
new_mean_node_name= utils.make_name(node.name)
1040+
new_mean_node = ctx.make_const(new_mean_node_name, "Const", new_mean_value)
1041+
node.input[3] = new_mean_node_name
1042+
1043+
if var_shape != scale_shape:
1044+
new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(), scale_shape), dtype=val_type)
1045+
new_val_node_name= utils.make_name(node.name)
1046+
new_var_node = ctx.make_const(new_val_node_name, "Const", new_var_value)
1047+
node.input[4] = new_val_node_name
1048+
10321049
return nodes
10331050

10341051

0 commit comments

Comments
 (0)