@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
855855 }
856856 // CUDNN only support small batch size
857857 bool use_native_nhwc =
858- d_x ? (x_dims.size () == 4 && compute_format == DataLayout::kNHWC )
858+ d_x ? (x_dims.size () == 4 && compute_format == DataLayout::kNHWC &&
859+ H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
859860 : false ;
860861 const bool use_native_kernel =
861862 ((x_dims.size () == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx,
933934 flag_ptr);
934935 }
935936 // 2. reduce_sum(x, dy, mean) => dscale, dbias
937+ BatchNormParamType<T> *dscale = nullptr ;
938+ BatchNormParamType<T> *dbias = nullptr ;
939+ bool with_scale = false ;
940+ if (d_scale && d_bias) {
941+ dscale = ctx.template Alloc <BatchNormParamType<T>>(d_scale);
942+ dbias = ctx.template Alloc <BatchNormParamType<T>>(d_bias);
943+ } else {
944+ DenseTensor dscale_mem =
945+ phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
946+ DenseTensor dbias_mem =
947+ phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
948+ dscale = dscale_mem.data <BatchNormParamType<T>>();
949+ dbias = dbias_mem.data <BatchNormParamType<T>>();
950+ }
951+
936952 BNBackward2DChannelLastStage2<T, block_size>
937953 <<<grid, block, 0 , ctx.stream()>>> (
938954 transformed_d_y.template data <T>(),
@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx,
944960 H * W * D,
945961 epsilon,
946962 block_data_ptr,
947- ctx. template Alloc <BatchNormParamType<T>>(d_scale) ,
948- ctx. template Alloc <BatchNormParamType<T>>(d_bias) ,
963+ dscale ,
964+ dbias ,
949965 flag_ptr);
950966
951967 // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx,
954970 transformed_d_y.template data <T>(),
955971 transformed_x.template data <T>(),
956972 scale.template data <BatchNormParamType<T>>(),
957- d_scale-> data <BatchNormParamType<T>>() ,
958- d_bias-> data <BatchNormParamType<T>>() ,
973+ dscale ,
974+ dbias ,
959975 mean_ptr,
960976 variance_ptr,
961977 C,
@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx,
11651181 paddle::platform::dynload::cudnnDestroyTensorDescriptor (
11661182 bn_param_desc_));
11671183#endif
1184+
11681185 } else {
11691186 const auto *running_mean = mean.get_ptr ();
11701187 const auto *running_var = variance.get_ptr ();
0 commit comments