@@ -114,23 +114,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
114
114
const auto *bias = ctx.Input <Tensor>(" Bias" );
115
115
116
116
auto *y = ctx.Output <Tensor>(" Y" );
117
- auto *mean_out = ctx.Output <Tensor>(" MeanOut" );
118
- auto *variance_out = ctx.Output <Tensor>(" VarianceOut" );
119
- auto *saved_mean = ctx.Output <Tensor>(" SavedMean" );
120
- auto *saved_variance = ctx.Output <Tensor>(" SavedVariance" );
121
117
122
118
// alloc memory
123
119
y->mutable_data <T>(ctx.GetPlace ());
124
- mean_out->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
125
- variance_out->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
126
- saved_mean->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
127
- saved_variance->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
128
120
129
121
auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
130
- math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
131
- functor;
132
- functor (dev_ctx, saved_mean, static_cast <BatchNormParamType<T>>(0 ));
133
- functor (dev_ctx, saved_variance, static_cast <BatchNormParamType<T>>(0 ));
134
122
135
123
auto handle = dev_ctx.cudnn_handle ();
136
124
@@ -159,6 +147,21 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
159
147
// Run training mode.
160
148
// obtain running mean and running inv var, and see if we need to
161
149
// initialize them.
150
+
151
+ auto *mean_out = ctx.Output <Tensor>(" MeanOut" );
152
+ auto *variance_out = ctx.Output <Tensor>(" VarianceOut" );
153
+ mean_out->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
154
+ variance_out->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
155
+
156
+ auto *saved_mean = ctx.Output <Tensor>(" SavedMean" );
157
+ auto *saved_variance = ctx.Output <Tensor>(" SavedVariance" );
158
+ saved_mean->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
159
+ saved_variance->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
160
+ math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
161
+ functor;
162
+ functor (dev_ctx, saved_mean, static_cast <BatchNormParamType<T>>(0 ));
163
+ functor (dev_ctx, saved_variance, static_cast <BatchNormParamType<T>>(0 ));
164
+
162
165
double this_factor = 1 . - momentum;
163
166
164
167
CUDNN_ENFORCE (platform::dynload::cudnnBatchNormalizationForwardTraining (
0 commit comments