@@ -67,27 +67,27 @@ template <typename T, int BlockDim>
 __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                  T *y, T *mean, T *var, float epsilon,
                                  int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * feature_size;
 
   // Step 1: Reduce to calculate mean and var
-  T mean_val = static_cast<T>(0);
-  T var_val = static_cast<T>(0);
+  double mean_val = 0;
+  double var_val = 0;
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
     T tmp = x[i];
     mean_val += tmp;
     var_val += (tmp * tmp);
   }
   auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<T>(mean_val, var_val),
-                          PairForLayerNormAddFunctor<T>());
+                  .Reduce(PairForLayerNorm<double>(mean_val, var_val),
+                          PairForLayerNormAddFunctor<double>());
   if (threadIdx.x == 0) {
     auto tmp = pair.first_ / feature_size;
-    mean[blockIdx.x] = tmp;
-    var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
+    mean[blockIdx.x] = static_cast<T>(tmp);
+    var[blockIdx.x] = static_cast<T>(pair.second_ / feature_size - tmp * tmp);
   }
   __syncthreads();
   mean_val = mean[blockIdx.x];
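The hunk above widens the accumulator type from T to double: the per-thread sums of x and x*x, and the E[x^2] - E[x]^2 variance they feed, are now computed at double precision, and only the final mean/var are cast back to T when written out. The following is a minimal host-side sketch, not part of the patch and with purely illustrative names, of the rounding drift this guards against; it uses a constant-valued feature vector whose true variance is exactly 0.

#include <cstdio>

int main() {
  // Hypothetical setup: one "row" with a large feature dimension where every
  // element holds the same value, so the true variance is exactly 0.
  const int feature_size = 1 << 20;
  const float v = 1.0001f;

  float sum_f = 0.0f, sq_f = 0.0f;  // accumulate in T (old code path)
  double sum_d = 0.0, sq_d = 0.0;   // accumulate in double (new code path)
  for (int i = 0; i < feature_size; ++i) {
    sum_f += v;
    sq_f += v * v;
    sum_d += v;
    sq_d += static_cast<double>(v) * v;
  }

  // Same E[x^2] - E[x]^2 formula the kernel uses.
  const float mean_f = sum_f / feature_size;
  const float var_f = sq_f / feature_size - mean_f * mean_f;
  const double mean_d = sum_d / feature_size;
  const double var_d = sq_d / feature_size - mean_d * mean_d;

  // The float path drifts: once the running sum is large, each += loses the
  // low-order bits of v, and the cancellation in the variance formula
  // magnifies that drift (var_f can even come out negative). The double
  // path stays near 0.
  std::printf("float : mean=%.7f var=%g\n", mean_f, var_f);
  std::printf("double: mean=%.7f var=%g\n", mean_d, var_d);
  return 0;
}

The same cancellation argument explains why the cast back to T happens only at the very end: subtracting two nearly equal quantities is the precision-critical step, so it is done in double before any narrowing.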