@@ -784,24 +784,49 @@ def version_6(cls, ctx, node, **kwargs):
 
         conv_convert_inputs(ctx, node, with_kernel=False)
 
+        inp_shape = ctx.get_shape(node.input[0])
+        inp_rank = len(inp_shape) if inp_shape is not None else None
         scale_shape = ctx.get_shape(node.input[1])
         mean_shape = ctx.get_shape(node.input[3])
         var_shape = ctx.get_shape(node.input[4])
         val_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))
-
-        if node.get_attr_value('is_training', 1) == 1:
+        is_training = node.get_attr_value('is_training', True)
+
+        if is_training and node.get_attr_value('exponential_avg_factor', 1.0) == 1.0:
+            # Sometimes TF uses a BatchNorm op with training = True and exponential_avg_factor = 1.0
+            # to perform layer mean/variance normalization. In such cases, the mean/var are computed from the input.
+            # TF allows mean/variance to be excluded only if is_training and exponential_avg_factor == 1.0
+            utils.make_sure(inp_rank is not None, "Cannot convert node %s of type %s with input of unknown rank.",
+                            node.name, tf_type)
+            dims = [0] + list(range(2, inp_rank))
+            avg = ctx.make_node("ReduceMean", [node.input[0]], attr={'axes': dims, 'keepdims': True}).output[0]
+            avg_squeezed = GraphBuilder(ctx).make_squeeze({"data": avg, "axes": dims})
+            sub = ctx.make_node("Sub", [node.input[0], avg]).output[0]
+            var_squeezed = ctx.make_node("ReduceSumSquare", [sub], attr={'axes': dims, 'keepdims': False}).output[0]
+
+            inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
+            dims_const = ctx.make_const(utils.make_name("axes_const"), np.array(dims, dtype=np.int64)).output[0]
+            reduce_dims = ctx.make_node("Gather", [inp_shape, dims_const]).output[0]
+            dims_product = ctx.make_node("ReduceProd", [reduce_dims], attr={'axes': [0], 'keepdims': False})
+            cnt_float = ctx.make_node("Cast", [dims_product.output[0]], attr={'to': ctx.get_dtype(node.input[0])})
+
+            pop_var_squeezed = ctx.make_node("Div", [var_squeezed, cnt_float.output[0]]).output[0]
+            ctx.replace_inputs(node, node.input[:3] + [avg_squeezed, pop_var_squeezed])
+        else:
             logger.warning("Node %s of type %s has is_training set to true, which is not supperted. "
                            "Please re-save the model with training set to false.",
                            node.name, tf_type)
+            # As long as the mean/variance estimates are provided, we should be OK
+            is_training = False
 
-        if mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
+        if not is_training and mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
             new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(as_list=False), scale_shape),
                                       dtype=val_type)
             new_mean_node_name = utils.make_name(node.name)
             ctx.make_const(new_mean_node_name, new_mean_value)
             ctx.replace_input(node, node.input[3], new_mean_node_name, 3)
 
-        if var_shape != scale_shape and all(d >= 0 for d in scale_shape):
+        if not is_training and var_shape != scale_shape and all(d >= 0 for d in scale_shape):
             new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(as_list=False), scale_shape),
                                      dtype=val_type)
             new_val_node_name = utils.make_name(node.name)
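
For reference, a minimal NumPy sketch (not part of the patch; names are illustrative) of the statistics that the inserted ReduceMean/Sub/ReduceSumSquare/ReduceProd/Div subgraph computes for an NCHW input, assuming the rank is known:

```python
import numpy as np

def batch_stats(x):
    # Reduce over the batch and spatial axes, keeping the channel axis (axis 1),
    # mirroring dims = [0] + list(range(2, inp_rank)) in the converter.
    dims = (0,) + tuple(range(2, x.ndim))
    avg = x.mean(axis=dims, keepdims=True)         # ReduceMean with keepdims=True
    sub = x - avg                                  # Sub
    sq_sum = np.sum(sub * sub, axis=dims)          # ReduceSumSquare with keepdims=False
    cnt = np.prod([x.shape[d] for d in dims])      # Shape -> Gather -> ReduceProd -> Cast
    return np.squeeze(avg, axis=dims), sq_sum / cnt  # per-channel mean and population variance

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
mean, pop_var = batch_stats(x)  # these replace the missing mean/variance inputs of the node
```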