Commit 10fbb83

Skip BatchNorm when feature only has 1 element. (#11578)
* Fix batch norm when there is only 1 element in the normalization dimension during training.
1 parent 110c6ae commit 10fbb83
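
The rationale is easiest to see numerically: when the normalization dimension contains a single value, the batch mean equals that value and the batch variance is exactly zero, so normalizing would collapse the output to zero (up to epsilon) and destroy the signal. A minimal NumPy sketch of that degenerate case (shapes and epsilon are illustrative only, not taken from the commit):

import numpy as np

# One sample, one channel, 1x1 spatial size: the normalization
# dimension (N * H * W) holds exactly one element.
x = np.array([3.0]).reshape(1, 1, 1, 1)
eps = 1e-5

mean = x.mean(axis=(0, 2, 3), keepdims=True)    # equals x itself
var = x.var(axis=(0, 2, 3), keepdims=True)      # exactly 0.0

y_normalized = (x - mean) / np.sqrt(var + eps)  # always 0, signal lost
y_skipped = x                                   # what the patched kernels do: y = x

print(y_normalized.ravel(), y_skipped.ravel())  # [0.] [3.]

This is why the kernels below fall back to an identity copy (y = x) instead of running the normalization path.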

5 files changed: +66 -40 lines

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 17 additions & 4 deletions
@@ -216,6 +216,18 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
     saved_mean_e.setZero();
     saved_variance_e.setZero();
 
+    EigenVectorArrayMap<T> running_mean_arr(
+        mean_out->mutable_data<T>(ctx.GetPlace()), C);
+    EigenVectorArrayMap<T> running_var_arr(
+        variance_out->mutable_data<T>(ctx.GetPlace()), C);
+
+    if ((N * sample_size) == 1) {
+      LOG(WARNING) << "Only 1 element in normalization dimension, "
+                   << "we skip the batch norm calculation, let y = x.";
+      framework::TensorCopySync(*x, ctx.GetPlace(), y);
+      return;
+    }
+
     switch (data_layout) {
       case DataLayout::kNCHW: {
         ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
@@ -247,10 +259,6 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
         PADDLE_THROW("Unknown storage order: %s", data_layout_str);
     }
 
-    EigenVectorArrayMap<T> running_mean_arr(
-        mean_out->mutable_data<T>(ctx.GetPlace()), C);
-    EigenVectorArrayMap<T> running_var_arr(
-        variance_out->mutable_data<T>(ctx.GetPlace()), C);
     running_mean_arr =
         running_mean_arr * momentum + saved_mean_e * (1. - momentum);
     running_var_arr =
@@ -427,6 +435,11 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     d_bias_arr.setZero();
     d_scale_arr.setZero();
 
+    if ((N * sample_size) == 1) {
+      framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
+      return;
+    }
+
     const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
 
     switch (data_layout) {
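
In the CPU kernel, sample_size is the number of spatial elements per channel, so the guard (N * sample_size) == 1 fires exactly when each channel's statistics would be computed from a single value, e.g. an input of shape [1, C, 1, 1]. A small helper mirroring this shape arithmetic, for illustration only (the kernel derives N and sample_size from the tensor dims directly; this function is hypothetical):

def normalization_elements(shape, layout="NCHW"):
    # Values reduced per channel when computing batch-norm statistics;
    # the patched kernels skip the computation when this equals 1.
    n = shape[0]
    spatial = shape[2:] if layout == "NCHW" else shape[1:-1]
    sample_size = 1
    for d in spatial:
        sample_size *= d
    return n * sample_size

print(normalization_elements((1, 64, 1, 1)))  # 1  -> skip, y = x
print(normalization_elements((1, 64, 7, 7)))  # 49 -> normal batch norm path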

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 45 additions & 32 deletions
@@ -72,6 +72,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
+    auto *y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());
+
     // ------------------- cudnn descriptors ---------------------
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
@@ -93,7 +96,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif
 
-    VLOG(1) << "Setting descriptors.";
+    VLOG(3) << "Setting descriptors.";
     std::vector<int> dims;
     std::vector<int> strides;
     if (data_layout == DataLayout::kNCHW) {
@@ -113,11 +116,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
 
-    auto *y = ctx.Output<Tensor>("Y");
-
-    // alloc memory
-    y->mutable_data<T>(ctx.GetPlace());
-
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
 
     auto handle = dev_ctx.cudnn_handle();
@@ -162,22 +160,28 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
       functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
 
-      double this_factor = 1. - momentum;
-
-      CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
-          handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
-          data_desc_, x->template data<T>(), data_desc_,
-          y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<BatchNormParamType<T>>(),
-          bias->template data<BatchNormParamType<T>>(), this_factor,
-          mean_out->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace()),
-          variance_out->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace()),
-          epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
-                       ctx.GetPlace()),
-          saved_variance->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace())));
+      if ((N * H * W * D) == 1) {
+        LOG(WARNING) << "Only 1 element in normalization dimension, "
+                     << "we skip the batch norm calculation, let y = x.";
+        framework::TensorCopySync(*x, ctx.GetPlace(), y);
+      } else {
+        double this_factor = 1. - momentum;
+
+        CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
+            handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
+            data_desc_, x->template data<T>(), data_desc_,
+            y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
+            scale->template data<BatchNormParamType<T>>(),
+            bias->template data<BatchNormParamType<T>>(), this_factor,
+            mean_out->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            variance_out->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+                         ctx.GetPlace()),
+            saved_variance->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace())));
+      }
     }
 
     // clean when exit.
@@ -209,6 +213,25 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
+    // init output
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    d_x->mutable_data<T>(ctx.GetPlace());
+    d_scale->mutable_data<T>(ctx.GetPlace());
+    d_bias->mutable_data<T>(ctx.GetPlace());
+
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    if ((N * H * W * D) == 1) {
+      framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
+      math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+          functor;
+      functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
+      functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
+      return;
+    }
+
     PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
     PADDLE_ENFORCE_EQ(scale->dims()[0], C);
 
@@ -247,21 +270,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
         bn_param_desc_, data_desc_, mode_));
 
-    // init output
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-    d_x->mutable_data<T>(ctx.GetPlace());
-    d_scale->mutable_data<T>(ctx.GetPlace());
-    d_bias->mutable_data<T>(ctx.GetPlace());
-
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
     const void *saved_mean_data = saved_mean->template data<T>();
     const void *saved_var_data = saved_var->template data<T>();
 
-    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
         CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
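
The degenerate-case backward pass has the same shape on both devices: the input gradient is simply the upstream gradient, and the scale/bias gradients are zero (on the CPU path d_scale_arr and d_bias_arr are zeroed just before the early return; on the GPU path SetConstant writes the zeros explicitly). A rough NumPy equivalent of that branch, for intuition only:

import numpy as np

def batch_norm_grad_skip(d_y, C):
    # Forward pass was y = x, which does not depend on scale or bias,
    # so d_x = d_y and the parameter gradients are zero.
    d_x = d_y.copy()
    d_scale = np.zeros(C, dtype=d_y.dtype)
    d_bias = np.zeros(C, dtype=d_y.dtype)
    return d_x, d_scale, d_bias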

paddle/fluid/operators/cross_entropy_op.cc

Lines changed: 1 addition & 2 deletions
@@ -124,8 +124,7 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
              "Tensor<float/double> with shape [N x D].");
     AddOutput("Y",
               "(Tensor, default Tensor<float>), a 2-D tensor with shape "
-              "[N x 1]. The cross entropy loss.")
-        .Reuse("X");
+              "[N x 1]. The cross entropy loss.");
     AddAttr<bool>("soft_label",
                   "(bool, default false), a flag indicating whether to "
                   "interpretate the given labels as soft labels.")

python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py

Lines changed: 0 additions & 1 deletion
@@ -40,7 +40,6 @@ def setUp(self):
         self.op_type = "fake_dequantize_max_abs"
         x = np.random.randn(31, 65).astype("float32")
         yq, scale = quantize_max_abs(x, self.num_bits)
-        print 'scale ', scale
         ydq = dequantize_max_abs(yq, self.num_bits, scale)
 
         self.inputs = {'X': yq}

python/paddle/fluid/tests/unittests/test_parallel_op.py

Lines changed: 3 additions & 1 deletion
@@ -113,7 +113,9 @@ def _run_test_impl_(self,
         generator = callback()
         # Automatically insert parallel do if use_parallel = True
         if use_parallel:
-            places = fluid.layers.get_places()
+            thread_num = fluid.core.get_cuda_device_count(
+            ) if use_gpu else 8
+            places = fluid.layers.get_places(thread_num)
             pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
             data = next(generator)
 