@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
855
855
}
856
856
// CUDNN only support small batch size
857
857
bool use_native_nhwc =
858
- d_x ? (x_dims.size () == 4 && compute_format == DataLayout::kNHWC )
858
+ d_x ? (x_dims.size () == 4 && compute_format == DataLayout::kNHWC &&
859
+ H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
859
860
: false ;
860
861
const bool use_native_kernel =
861
862
((x_dims.size () == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx,
933
934
flag_ptr);
934
935
}
935
936
// 2. reduce_sum(x, dy, mean) => dscale, dbias
937
+ BatchNormParamType<T> *dscale = nullptr ;
938
+ BatchNormParamType<T> *dbias = nullptr ;
939
+ bool with_scale = false ;
940
+ if (d_scale && d_bias) {
941
+ dscale = ctx.template Alloc <BatchNormParamType<T>>(d_scale);
942
+ dbias = ctx.template Alloc <BatchNormParamType<T>>(d_bias);
943
+ } else {
944
+ DenseTensor dscale_mem =
945
+ phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
946
+ DenseTensor dbias_mem =
947
+ phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
948
+ dscale = dscale_mem.data <BatchNormParamType<T>>();
949
+ dbias = dbias_mem.data <BatchNormParamType<T>>();
950
+ }
951
+
936
952
BNBackward2DChannelLastStage2<T, block_size>
937
953
<<<grid, block, 0 , ctx.stream()>>> (
938
954
transformed_d_y.template data <T>(),
@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx,
944
960
H * W * D,
945
961
epsilon,
946
962
block_data_ptr,
947
- ctx. template Alloc <BatchNormParamType<T>>(d_scale) ,
948
- ctx. template Alloc <BatchNormParamType<T>>(d_bias) ,
963
+ dscale ,
964
+ dbias ,
949
965
flag_ptr);
950
966
951
967
// 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx,
954
970
transformed_d_y.template data <T>(),
955
971
transformed_x.template data <T>(),
956
972
scale.template data <BatchNormParamType<T>>(),
957
- d_scale-> data <BatchNormParamType<T>>() ,
958
- d_bias-> data <BatchNormParamType<T>>() ,
973
+ dscale ,
974
+ dbias ,
959
975
mean_ptr,
960
976
variance_ptr,
961
977
C,
@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx,
1165
1181
paddle::platform::dynload::cudnnDestroyTensorDescriptor (
1166
1182
bn_param_desc_));
1167
1183
#endif
1184
+
1168
1185
} else {
1169
1186
const auto *running_mean = mean.get_ptr ();
1170
1187
const auto *running_var = variance.get_ptr ();
0 commit comments