@@ -72,6 +72,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
72
72
int N, C, H, W, D;
73
73
ExtractNCWHD (x_dims, data_layout, &N, &C, &H, &W, &D);
74
74
75
+ auto *y = ctx.Output <Tensor>(" Y" );
76
+ y->mutable_data <T>(ctx.GetPlace ());
77
+
75
78
// ------------------- cudnn descriptors ---------------------
76
79
cudnnTensorDescriptor_t data_desc_;
77
80
cudnnTensorDescriptor_t bn_param_desc_;
@@ -93,7 +96,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
93
96
mode_ = CUDNN_BATCHNORM_SPATIAL;
94
97
#endif
95
98
96
- VLOG (1 ) << " Setting descriptors." ;
99
+ VLOG (3 ) << " Setting descriptors." ;
97
100
std::vector<int > dims;
98
101
std::vector<int > strides;
99
102
if (data_layout == DataLayout::kNCHW ) {
@@ -113,11 +116,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
113
116
const auto *scale = ctx.Input <Tensor>(" Scale" );
114
117
const auto *bias = ctx.Input <Tensor>(" Bias" );
115
118
116
- auto *y = ctx.Output <Tensor>(" Y" );
117
-
118
- // alloc memory
119
- y->mutable_data <T>(ctx.GetPlace ());
120
-
121
119
auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
122
120
123
121
auto handle = dev_ctx.cudnn_handle ();
@@ -162,22 +160,28 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
162
160
functor (dev_ctx, saved_mean, static_cast <BatchNormParamType<T>>(0 ));
163
161
functor (dev_ctx, saved_variance, static_cast <BatchNormParamType<T>>(0 ));
164
162
165
- double this_factor = 1 . - momentum;
166
-
167
- CUDNN_ENFORCE (platform::dynload::cudnnBatchNormalizationForwardTraining (
168
- handle, mode_, CudnnDataType<T>::kOne (), CudnnDataType<T>::kZero (),
169
- data_desc_, x->template data <T>(), data_desc_,
170
- y->template mutable_data <T>(ctx.GetPlace ()), bn_param_desc_,
171
- scale->template data <BatchNormParamType<T>>(),
172
- bias->template data <BatchNormParamType<T>>(), this_factor,
173
- mean_out->template mutable_data <BatchNormParamType<T>>(
174
- ctx.GetPlace ()),
175
- variance_out->template mutable_data <BatchNormParamType<T>>(
176
- ctx.GetPlace ()),
177
- epsilon, saved_mean->template mutable_data <BatchNormParamType<T>>(
178
- ctx.GetPlace ()),
179
- saved_variance->template mutable_data <BatchNormParamType<T>>(
180
- ctx.GetPlace ())));
163
+ if ((N * H * W * D) == 1 ) {
164
+ LOG (WARNING) << " Only 1 element in normalization dimension, "
165
+ << " we skip the batch norm calculation, let y = x." ;
166
+ framework::TensorCopySync (*x, ctx.GetPlace (), y);
167
+ } else {
168
+ double this_factor = 1 . - momentum;
169
+
170
+ CUDNN_ENFORCE (platform::dynload::cudnnBatchNormalizationForwardTraining (
171
+ handle, mode_, CudnnDataType<T>::kOne (), CudnnDataType<T>::kZero (),
172
+ data_desc_, x->template data <T>(), data_desc_,
173
+ y->template mutable_data <T>(ctx.GetPlace ()), bn_param_desc_,
174
+ scale->template data <BatchNormParamType<T>>(),
175
+ bias->template data <BatchNormParamType<T>>(), this_factor,
176
+ mean_out->template mutable_data <BatchNormParamType<T>>(
177
+ ctx.GetPlace ()),
178
+ variance_out->template mutable_data <BatchNormParamType<T>>(
179
+ ctx.GetPlace ()),
180
+ epsilon, saved_mean->template mutable_data <BatchNormParamType<T>>(
181
+ ctx.GetPlace ()),
182
+ saved_variance->template mutable_data <BatchNormParamType<T>>(
183
+ ctx.GetPlace ())));
184
+ }
181
185
}
182
186
183
187
// Clean up on exit.
@@ -209,6 +213,25 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
209
213
int N, C, H, W, D;
210
214
ExtractNCWHD (x_dims, data_layout, &N, &C, &H, &W, &D);
211
215
216
+ // Initialize the gradient outputs.
217
+ auto *d_x = ctx.Output <Tensor>(framework::GradVarName (" X" ));
218
+ auto *d_scale = ctx.Output <Tensor>(framework::GradVarName (" Scale" ));
219
+ auto *d_bias = ctx.Output <Tensor>(framework::GradVarName (" Bias" ));
220
+
221
+ d_x->mutable_data <T>(ctx.GetPlace ());
222
+ d_scale->mutable_data <T>(ctx.GetPlace ());
223
+ d_bias->mutable_data <T>(ctx.GetPlace ());
224
+
225
+ auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
226
+ if ((N * H * W * D) == 1 ) {
227
+ framework::TensorCopySync (*d_y, ctx.GetPlace (), d_x);
228
+ math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
229
+ functor;
230
+ functor (dev_ctx, d_scale, static_cast <BatchNormParamType<T>>(0 ));
231
+ functor (dev_ctx, d_bias, static_cast <BatchNormParamType<T>>(0 ));
232
+ return ;
233
+ }
234
+
212
235
PADDLE_ENFORCE_EQ (scale->dims ().size (), 1UL );
213
236
PADDLE_ENFORCE_EQ (scale->dims ()[0 ], C);
214
237
@@ -247,21 +270,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
247
270
CUDNN_ENFORCE (platform::dynload::cudnnDeriveBNTensorDescriptor (
248
271
bn_param_desc_, data_desc_, mode_));
249
272
250
- // init output
251
- auto *d_x = ctx.Output <Tensor>(framework::GradVarName (" X" ));
252
- auto *d_scale = ctx.Output <Tensor>(framework::GradVarName (" Scale" ));
253
- auto *d_bias = ctx.Output <Tensor>(framework::GradVarName (" Bias" ));
254
-
255
- d_x->mutable_data <T>(ctx.GetPlace ());
256
- d_scale->mutable_data <T>(ctx.GetPlace ());
257
- d_bias->mutable_data <T>(ctx.GetPlace ());
258
-
259
273
const auto *saved_mean = ctx.Input <Tensor>(" SavedMean" );
260
274
const auto *saved_var = ctx.Input <Tensor>(" SavedVariance" );
261
275
const void *saved_mean_data = saved_mean->template data <T>();
262
276
const void *saved_var_data = saved_var->template data <T>();
263
277
264
- auto &dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
265
278
CUDNN_ENFORCE (platform::dynload::cudnnBatchNormalizationBackward (
266
279
dev_ctx.cudnn_handle (), mode_, CudnnDataType<T>::kOne (),
267
280
CudnnDataType<T>::kZero (), CudnnDataType<T>::kOne (),
0 commit comments