@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
72
72
}
73
73
}
74
74
75
// Precomputes the per-channel inverse standard deviation used by the
// 1-D batch-norm inference kernel:
//   inv_variance[c] = 1 / sqrt(variance[c] + epsilon)
// Launch with a 1-D grid covering C threads; one thread per channel,
// extra threads fall through the bounds guard.
template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
                                       const double epsilon,
                                       const int C,
                                       BatchNormParamType<T> *inv_variance) {
  const int channel = blockIdx.x * blockDim.x + threadIdx.x;
  if (channel < C) {
    // epsilon is double, so the sum and sqrt run in double precision
    // before narrowing to BatchNormParamType<T> on the store.
    inv_variance[channel] = 1 / sqrt(variance[channel] + epsilon);
  }
}
85
+
86
// Batch-norm inference transform using a precomputed inverse standard
// deviation (see InverseVariance):
//   y[i] = scale[c] * (x[i] - mean[c]) * inv_variance[c] + bias[c]
// Grid-stride loop over all N*C*HxW elements, so any 1-D launch
// configuration is correct. The channel index is derived from the flat
// offset according to the layout template parameter (kNCHW vs. NHWC-style).
// NOTE: `epsilon` is not read here — it was already folded into
// inv_variance by the producer kernel; the parameter is kept for call-site
// symmetry with BNForwardInference.
template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
    const T *x,
    const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *inv_variance,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    T *y) {
  const int total = N * C * HxW;
  const int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += step) {
    const int c =
        (layout == phi::DataLayout::kNCHW) ? (idx / HxW) % C : idx % C;
    // Accumulate in the (wider) parameter type, then narrow back to T.
    const BatchNormParamType<T> centered =
        static_cast<BatchNormParamType<T>>(x[idx]) - mean[c];
    y[idx] = static_cast<T>(scale[c] * centered * inv_variance[c] + bias[c]);
  }
}
108
+
75
109
template <typename T, int BlockDim, phi::DataLayout layout>
76
110
static __global__ LAUNCH_BOUNDS (BlockDim) void BNForwardTraining(
77
111
const T *x,
@@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx,
691
725
692
726
auto handle = ctx.cudnn_handle ();
693
727
694
- const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240 ;
695
- const size_t CUDNN_SPATIAL_THRESHOLD = 880801 ;
696
-
697
728
// Now, depending on whether we are running test or not, we have two paths.
698
729
// It is training mode when it's not reference AND not using pre-trained
699
730
// model.
@@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx,
797
828
// epsilon));
798
829
#else
799
830
const bool use_native_kernel =
800
- (( x_dims.size () == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
801
- (x_dims.size () == 3 && N >= CUDNN_SPATIAL_THRESHOLD ));
831
+ (x_dims.size () == 2 ||
832
+ (x_dims.size () == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL ));
802
833
if (use_native_kernel) {
803
834
const int block_size = 256 ;
804
835
const int grid_size = (N * C * H * W * D + block_size - 1 ) / block_size;
@@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx,
816
847
epsilon,
817
848
transformed_y.template data <T>());
818
849
} else {
819
- BNForwardInference<T, DataLayout::kNHWC >
820
- <<<grid_size, block_size, 0 , ctx.stream()>>> (
821
- transformed_x.template data <T>(),
822
- est_mean->template data <BatchNormParamType<T>>(),
823
- est_var->template data <BatchNormParamType<T>>(),
824
- scale.template data <BatchNormParamType<T>>(),
825
- bias.template data <BatchNormParamType<T>>(),
826
- C,
827
- N,
828
- H * W * D,
829
- epsilon,
830
- transformed_y.template data <T>());
850
+ if (x_dims.size () == 2 ) {
851
+ DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
852
+ auto *inv_var_ptr = inv_var.data <BatchNormParamType<T>>();
853
+ const int threads = 512 > C ? C : 512 ;
854
+ const int blocks = (C + 511 ) / 512 ;
855
+ InverseVariance<T><<<blocks, threads>>> (
856
+ est_var->template data <BatchNormParamType<T>>(),
857
+ epsilon,
858
+ C,
859
+ inv_var_ptr);
860
+ BN1DForwardInference<T, DataLayout::kNHWC >
861
+ <<<grid_size, block_size, 0 , ctx.stream()>>> (
862
+ transformed_x.template data <T>(),
863
+ est_mean->template data <BatchNormParamType<T>>(),
864
+ // est_var->template data<BatchNormParamType<T>>(),
865
+ inv_var_ptr,
866
+ scale.template data <BatchNormParamType<T>>(),
867
+ bias.template data <BatchNormParamType<T>>(),
868
+ C,
869
+ N,
870
+ H * W * D,
871
+ epsilon,
872
+ transformed_y.template data <T>());
873
+ } else {
874
+ BNForwardInference<T, DataLayout::kNHWC >
875
+ <<<grid_size, block_size, 0 , ctx.stream()>>> (
876
+ transformed_x.template data <T>(),
877
+ est_mean->template data <BatchNormParamType<T>>(),
878
+ est_var->template data <BatchNormParamType<T>>(),
879
+ scale.template data <BatchNormParamType<T>>(),
880
+ bias.template data <BatchNormParamType<T>>(),
881
+ C,
882
+ N,
883
+ H * W * D,
884
+ epsilon,
885
+ transformed_y.template data <T>());
886
+ }
831
887
}
832
888
} else {
833
889
PADDLE_ENFORCE_GPU_SUCCESS (
@@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx,
949
1005
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
950
1006
const bool use_native_kernel =
951
1007
((x_dims.size () == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
952
- (x_dims.size () == 3 && N >= CUDNN_SPATIAL_THRESHOLD ));
1008
+ (x_dims.size () == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN ));
953
1009
if (use_native_kernel) {
954
1010
dim3 block;
955
1011
dim3 grid;
0 commit comments