
Commit eb61074

Author: zhangkaihuo
[cherry-pick] Fix bn performance degradation (#50382)
att, cherry-pick: #48563, #50287
1 parent 59fec5d commit eb61074

File tree

1 file changed: +22 −5 lines changed

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 22 additions & 5 deletions
@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
   }
   // CUDNN only support small batch size
   bool use_native_nhwc =
-      d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+      d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
+             H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
           : false;
   const bool use_native_kernel =
       ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
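
Note on the first hunk: the native NHWC backward path is now additionally gated on the spatial size, so small feature maps stay on the cuDNN path. A minimal, self-contained sketch of that dispatch rule follows; the helper name and the threshold value are hypothetical (the real CUDNN_SPATIAL_THRESHOLD_EVAL constant is defined elsewhere in the Paddle sources), so read it as an illustration, not the kernel's actual code.

#include <cstdint>
#include <cstdio>

// Placeholder threshold; the real CUDNN_SPATIAL_THRESHOLD_EVAL is defined
// elsewhere in Paddle and may have a different value.
constexpr std::int64_t kSpatialThresholdEval = 1 << 20;

// Hypothetical helper mirroring the updated condition: the hand-written NHWC
// backward path is taken only for 4-D NHWC inputs whose spatial size H * W
// reaches the threshold; smaller feature maps keep using cuDNN.
bool UseNativeNhwcBackward(bool has_dx, int rank, bool is_nhwc,
                           std::int64_t H, std::int64_t W) {
  return has_dx && rank == 4 && is_nhwc && H * W >= kSpatialThresholdEval;
}

int main() {
  std::printf("224x224   -> native path? %d\n",
              UseNativeNhwcBackward(true, 4, true, 224, 224));
  std::printf("2048x2048 -> native path? %d\n",
              UseNativeNhwcBackward(true, 4, true, 2048, 2048));
  return 0;
}
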
@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx,
             flag_ptr);
       }
       // 2. reduce_sum(x, dy, mean) => dscale, dbias
+      BatchNormParamType<T> *dscale = nullptr;
+      BatchNormParamType<T> *dbias = nullptr;
+      bool with_scale = false;
+      if (d_scale && d_bias) {
+        dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+        dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+      } else {
+        DenseTensor dscale_mem =
+            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+        DenseTensor dbias_mem =
+            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+        dscale = dscale_mem.data<BatchNormParamType<T>>();
+        dbias = dbias_mem.data<BatchNormParamType<T>>();
+      }
+
       BNBackward2DChannelLastStage2<T, block_size>
           <<<grid, block, 0, ctx.stream()>>>(
               transformed_d_y.template data<T>(),
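
Reading of the second hunk: the stage-2 reduction kernel always writes the per-channel scale/bias gradients, while callers may not request d_scale/d_bias at all, so scratch tensors are allocated as a fallback. Below is a self-contained CPU sketch of that same "optional output, scratch fallback" pattern; the names (ReduceScaleBiasGrad, BackwardStage2) and the arithmetic are placeholders, not the real batch-norm gradient formulas or Paddle APIs.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the stage-2 reduction kernel: it unconditionally
// writes per-channel scale/bias gradients, so it always needs valid buffers.
// The arithmetic here is a placeholder, not the BN gradient computation.
void ReduceScaleBiasGrad(const std::vector<float>& dy, int C,
                         float* dscale, float* dbias) {
  for (int c = 0; c < C; ++c) {
    dscale[c] = 0.f;
    dbias[c] = 0.f;
  }
  for (std::size_t i = 0; i < dy.size(); ++i) {
    dbias[i % C] += dy[i];
    dscale[i % C] += dy[i];
  }
}

// Mirrors the shape of the hunk: when the caller did not request d_scale or
// d_bias, fall back to scratch storage so the kernel can still write somewhere.
void BackwardStage2(const std::vector<float>& dy, int C,
                    float* d_scale /* may be null */,
                    float* d_bias /* may be null */) {
  std::vector<float> scratch_scale, scratch_bias;  // outlives the kernel call
  float* dscale = d_scale;
  float* dbias = d_bias;
  if (dscale == nullptr || dbias == nullptr) {
    scratch_scale.assign(C, 0.f);
    scratch_bias.assign(C, 0.f);
    dscale = scratch_scale.data();
    dbias = scratch_bias.data();
  }
  ReduceScaleBiasGrad(dy, C, dscale, dbias);  // unwanted results are discarded
}

int main() {
  std::vector<float> dy(8, 1.f);
  BackwardStage2(dy, /*C=*/2, nullptr, nullptr);  // exercises the scratch path
  std::puts("stage-2 backward ran without d_scale/d_bias being requested");
  return 0;
}
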
@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               H * W * D,
               epsilon,
               block_data_ptr,
-              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-              ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+              dscale,
+              dbias,
               flag_ptr);

       // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               transformed_d_y.template data<T>(),
               transformed_x.template data<T>(),
               scale.template data<BatchNormParamType<T>>(),
-              d_scale->data<BatchNormParamType<T>>(),
-              d_bias->data<BatchNormParamType<T>>(),
+              dscale,
+              dbias,
               mean_ptr,
               variance_ptr,
               C,
@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx,
     paddle::platform::dynload::cudnnDestroyTensorDescriptor(
         bn_param_desc_));
 #endif
+
   } else {
     const auto *running_mean = mean.get_ptr();
     const auto *running_var = variance.get_ptr();
