@@ -1018,16 +1018,21 @@ def initialize(self):
10181018 dims = inp .dim_names
10191019 self .add_output_variable (shape , dims )
10201020
1021- gamma = self .get_attr ('gamma_data' )
1022- beta = self .get_attr ('beta_data ' )
1023- mean = self .get_attr ('mean_data ' )
1024- var = self .get_attr ('variance_data' )
1025-
1026- scale = gamma / np . sqrt ( var + self . get_attr ( 'epsilon' ))
1027- bias = beta - scale * mean
1021+ if self .get_attr ('scale_data' ) is None :
1022+ gamma = self .get_attr ('gamma_data ' )
1023+ var = self .get_attr ('variance_data ' )
1024+ scale = gamma / np . sqrt ( var + self .get_attr ('epsilon' ) )
1025+ self . add_weights_variable ( name = 'scale' , var_name = 's{index}' , data = scale )
1026+ else :
1027+ self . add_weights_variable ( name = ' scale' , var_name = 's{index}' )
10281028
1029- self .add_weights_variable (name = 'scale' , var_name = 's{index}' , data = scale )
1030- self .add_weights_variable (name = 'bias' , var_name = 'b{index}' , data = bias )
1029+ if self .get_attr ('bias_data' ) is None :
1030+ beta = self .get_attr ('beta_data' )
1031+ mean = self .get_attr ('mean_data' )
1032+ bias = beta - scale * mean
1033+ self .add_weights_variable (name = 'bias' , var_name = 'b{index}' , data = bias )
1034+ else :
1035+ self .add_weights_variable (name = 'bias' , var_name = 'b{index}' )
10311036
10321037
10331038# TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense
0 commit comments