Skip to content

Commit fe341ba

Browse files
authored
refine batch norm python layer (#7348)
1 parent 1566af8 commit fe341ba

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

python/paddle/v2/fluid/layer_helper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,12 @@ def get_parameter(self, name):
120120
raise ValueError("no Parameter name %s found" % name)
121121
return param
122122

123-
def create_tmp_variable(self, dtype):
123+
def create_tmp_variable(self, dtype, stop_gradient=False):
124124
return self.main_program.current_block().create_var(
125125
name=unique_name(".".join([self.name, 'tmp'])),
126126
dtype=dtype,
127-
persistable=False)
127+
persistable=False,
128+
stop_gradient=stop_gradient)
128129

129130
def create_variable(self, *args, **kwargs):
130131
return self.main_program.current_block().create_var(*args, **kwargs)

python/paddle/v2/fluid/layers/nn.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -971,20 +971,26 @@ def batch_norm(input,
971971
attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=True)
972972

973973
mean = helper.create_global_variable(
974-
dtype=input.dtype, shape=param_shape, persistable=True)
974+
dtype=input.dtype,
975+
shape=param_shape,
976+
persistable=True,
977+
stop_gradient=True)
975978
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
976979

977980
variance = helper.create_global_variable(
978-
dtype=input.dtype, shape=param_shape, persistable=True)
981+
dtype=input.dtype,
982+
shape=param_shape,
983+
persistable=True,
984+
stop_gradient=True)
979985
helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
980986

981987
# create output
982988
# mean and mean_out share the same memory
983989
mean_out = mean
984990
# variance and variance out share the same memory
985991
variance_out = variance
986-
saved_mean = helper.create_tmp_variable(dtype)
987-
saved_variance = helper.create_tmp_variable(dtype)
992+
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
993+
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
988994

989995
batch_norm_out = helper.create_tmp_variable(dtype)
990996

0 commit comments

Comments
 (0)