File tree Expand file tree Collapse file tree 2 files changed +13
-6
lines changed Expand file tree Collapse file tree 2 files changed +13
-6
lines changed Original file line number Diff line number Diff line change @@ -120,11 +120,12 @@ def get_parameter(self, name):
120
120
raise ValueError ("no Parameter name %s found" % name )
121
121
return param
122
122
123
- def create_tmp_variable (self , dtype ):
123
+ def create_tmp_variable (self , dtype , stop_gradient = False ):
124
124
return self .main_program .current_block ().create_var (
125
125
name = unique_name ("." .join ([self .name , 'tmp' ])),
126
126
dtype = dtype ,
127
- persistable = False )
127
+ persistable = False ,
128
+ stop_gradient = stop_gradient )
128
129
129
130
def create_variable (self , * args , ** kwargs ):
130
131
return self .main_program .current_block ().create_var (* args , ** kwargs )
Original file line number Diff line number Diff line change @@ -971,20 +971,26 @@ def batch_norm(input,
971
971
attr = helper .param_attr , shape = param_shape , dtype = dtype , is_bias = True )
972
972
973
973
mean = helper .create_global_variable (
974
- dtype = input .dtype , shape = param_shape , persistable = True )
974
+ dtype = input .dtype ,
975
+ shape = param_shape ,
976
+ persistable = True ,
977
+ stop_gradient = True )
975
978
helper .set_variable_initializer (var = mean , initializer = Constant (0.0 ))
976
979
977
980
variance = helper .create_global_variable (
978
- dtype = input .dtype , shape = param_shape , persistable = True )
981
+ dtype = input .dtype ,
982
+ shape = param_shape ,
983
+ persistable = True ,
984
+ stop_gradient = True )
979
985
helper .set_variable_initializer (var = variance , initializer = Constant (1.0 ))
980
986
981
987
# create output
982
988
# mean and mean_out share the same memory
983
989
mean_out = mean
984
990
# variance and variance out share the same memory
985
991
variance_out = variance
986
- saved_mean = helper .create_tmp_variable (dtype )
987
- saved_variance = helper .create_tmp_variable (dtype )
992
+ saved_mean = helper .create_tmp_variable (dtype = dtype , stop_gradient = True )
993
+ saved_variance = helper .create_tmp_variable (dtype = dtype , stop_gradient = True )
988
994
989
995
batch_norm_out = helper .create_tmp_variable (dtype )
990
996
You can’t perform that action at this time.
0 commit comments