@@ -1027,8 +1027,25 @@ def fused_batchnorm_op7(ctx, node, name, args):
1027
1027
# tf outputs: y, batch_mean, batch_var
1028
1028
# a: data_format, epsilon, is_training
1029
1029
# onnx inputs: X, scale, B, mean, variance, attributes: epsilon, momentum=0.9, spatial : 1
1030
- # output: mean, var, savedmean, savedvar,
1030
+ # output: y, mean, var, savedmean, savedvar,
1031
1031
nodes = conv_convert_inputs (ctx , node , with_kernel = False )
1032
+ scale_shape = ctx .get_shape (node .input [1 ])
1033
+ mean_shape = ctx .get_shape (node .input [3 ])
1034
+ var_shape = ctx .get_shape (node .input [4 ])
1035
+ val_type = utils .ONNX_TO_NUMPY_DTYPE [node .inputs [1 ].dtype ]
1036
+
1037
+ if mean_shape != scale_shape :
1038
+ new_mean_value = np .array (np .resize (node .inputs [3 ].get_tensor_value (), scale_shape ), dtype = val_type )
1039
+ new_mean_node_name = utils .make_name (node .name )
1040
+ new_mean_node = ctx .make_const (new_mean_node_name , "Const" , new_mean_value )
1041
+ node .input [3 ] = new_mean_node_name
1042
+
1043
+ if var_shape != scale_shape :
1044
+ new_var_value = np .array (np .resize (node .inputs [4 ].get_tensor_value (), scale_shape ), dtype = val_type )
1045
+ new_val_node_name = utils .make_name (node .name )
1046
+ new_var_node = ctx .make_const (new_val_node_name , "Const" , new_var_value )
1047
+ node .input [4 ] = new_val_node_name
1048
+
1032
1049
return nodes
1033
1050
1034
1051
0 commit comments