Skip to content

Commit 859fedf

Browse files
authored
Merge pull request #9871 from qingqing01/fix_bn
Refine batch_norm_op.
2 parents ddf5783 + 1204d9f commit 859fedf

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -114,23 +114,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
114114
const auto *bias = ctx.Input<Tensor>("Bias");
115115

116116
auto *y = ctx.Output<Tensor>("Y");
117-
auto *mean_out = ctx.Output<Tensor>("MeanOut");
118-
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
119-
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
120-
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
121117

122118
// alloc memory
123119
y->mutable_data<T>(ctx.GetPlace());
124-
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
125-
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
126-
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
127-
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
128120

129121
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
130-
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
131-
functor;
132-
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
133-
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
134122

135123
auto handle = dev_ctx.cudnn_handle();
136124

@@ -159,6 +147,21 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
159147
// Run training mode.
160148
// obtain running mean and running inv var, and see if we need to
161149
// initialize them.
150+
151+
auto *mean_out = ctx.Output<Tensor>("MeanOut");
152+
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
153+
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
154+
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
155+
156+
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
157+
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
158+
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
159+
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
160+
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
161+
functor;
162+
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
163+
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
164+
162165
double this_factor = 1. - momentum;
163166

164167
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(

0 commit comments

Comments (0)