@@ -45,38 +45,55 @@ inline static int GetDesiredBlockDim(int block_dim) {
 static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
 static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
 
+template <typename T>
+struct PairForLayerNorm {
+  __device__ __forceinline__ PairForLayerNorm() {}
+  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
+      : first_(first), second_(second) {}
+
+  T first_;
+  T second_;
+};
+
+template <typename T>
+struct PairForLayerNormAddFunctor {
+  __device__ __forceinline__ PairForLayerNorm<T> operator()(
+      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
+    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
+  }
+};
+
 template <typename T, int BlockDim>
 __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                  T *y, T *mean, T *var, float epsilon,
                                  int feature_size) {
-  using BlockReduce = cub::BlockReduce<T, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * feature_size;
 
-  // Step 1: Reduce to calculate mean
+  // Step 1: Reduce to calculate mean and var
   T mean_val = static_cast<T>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    mean_val += x[i];
-  }
-  mean_val = BlockReduce(temp_storage).Reduce(mean_val, cub::Sum());
-  if (threadIdx.x == 0) mean[blockIdx.x] = mean_val / feature_size;
-  __syncthreads();
-  mean_val = mean[blockIdx.x];
-
-  // Step 2: Reduce to calculate var
   T var_val = static_cast<T>(0);
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T tmp = x[i] - mean_val;
+    T tmp = x[i];
+    mean_val += tmp;
     var_val += (tmp * tmp);
   }
-  var_val = BlockReduce(temp_storage).Reduce(var_val, cub::Sum());
-  if (threadIdx.x == 0) var[blockIdx.x] = var_val / feature_size;
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(mean_val, var_val),
+                          PairForLayerNormAddFunctor<T>());
+  if (threadIdx.x == 0) {
+    auto tmp = pair.first_ / feature_size;
+    mean[blockIdx.x] = tmp;
+    var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
+  }
   __syncthreads();
+  mean_val = mean[blockIdx.x];
   var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
 
-  // Step 3: Calculate y
+  // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
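
Note on the forward change above: instead of two block reductions (one for the mean, then one for the variance), the kernel now reduces the pair (sum of x, sum of x^2) once and recovers the variance via Var(x) = E[x^2] - (E[x])^2, so x is read once instead of twice, at the cost of the usual cancellation risk of that formula when the mean is large relative to the variance. A minimal standalone sketch of the same idea, assuming a single block, float data, and stand-in names Pair/PairAdd for PairForLayerNorm/PairForLayerNormAddFunctor:

#include <cub/cub.cuh>

struct Pair {
  float sum, sq_sum;
  __device__ Pair() {}
  __device__ Pair(float s, float ss) : sum(s), sq_sum(ss) {}
};

struct PairAdd {
  __device__ Pair operator()(const Pair &p, const Pair &q) const {
    return Pair(p.sum + q.sum, p.sq_sum + q.sq_sum);
  }
};

template <int BlockDim>
__global__ void MeanVarOnePass(const float *x, int n, float *mean, float *var) {
  using BlockReduce = cub::BlockReduce<Pair, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float sum = 0.f, sq_sum = 0.f;
  for (int i = threadIdx.x; i < n; i += BlockDim) {
    sum += x[i];
    sq_sum += x[i] * x[i];
  }
  // One reduction yields both partial sums; thread 0 derives mean and var.
  Pair p = BlockReduce(temp_storage).Reduce(Pair(sum, sq_sum), PairAdd());
  if (threadIdx.x == 0) {
    float m = p.sum / n;
    *mean = m;
    *var = p.sq_sum / n - m * m;  // Var(x) = E[x^2] - (E[x])^2
  }
}

A launch such as MeanVarOnePass<256><<<1, 256>>>(d_x, n, d_mean, d_var) mirrors one row of the patched kernel.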
@@ -104,27 +121,6 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
   }
 }
 
-template <typename T>
-struct PairForLayerNormBackward {
-  __device__ __forceinline__ PairForLayerNormBackward() {}
-  __device__ __forceinline__ PairForLayerNormBackward(const T &first,
-                                                      const T &second)
-      : first_(first), second_(second) {}
-
-  T first_;
-  T second_;
-};
-
-template <typename T>
-struct PairForLayerNormBackwardAddFunctor {
-  __device__ __forceinline__ PairForLayerNormBackward<T> operator()(
-      const PairForLayerNormBackward<T> &p1,
-      const PairForLayerNormBackward<T> &p2) {
-    return PairForLayerNormBackward<T>(p1.first_ + p2.first_,
-                                       p1.second_ + p2.second_);
-  }
-};
-
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T, int BlockDim, bool HasDx>
@@ -133,26 +129,28 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                              const T *mean, const T *var,
                                              const T *scale, float epsilon,
                                              int batch_size, int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNormBackward<T>, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = threadIdx.x * feature_size + blockIdx.x;
   int end_idx = batch_size * feature_size + blockIdx.x;
   int stride = BlockDim * feature_size;
+
   T d_scale_partial = 0, d_bias_partial = 0;
 
   for (int i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
     d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
     d_bias_partial += d_y[i];
-    if (HasDx) d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    if (HasDx) {
+      d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    }
   }
 
-  auto pair =
-      BlockReduce(temp_storage)
-          .Reduce(PairForLayerNormBackward<T>(d_scale_partial, d_bias_partial),
-                  PairForLayerNormBackwardAddFunctor<T>());
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
+                          PairForLayerNormAddFunctor<T>());
 
   if (threadIdx.x == 0) {
     d_scale[blockIdx.x] = pair.first_;
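
In this fused d_scale/d_bias kernel each block owns one column j = blockIdx.x and its threads stride over rows (stride = BlockDim * feature_size), so a single pair reduction produces both column gradients. A host-side sketch of what one column accumulates, with my own illustrative helper name ColumnGrads (not part of the patch):

#include <cmath>
#include <vector>

// Per-column accumulation mirrored from LayerNormBackwardGradientAll:
// d_scale[j] sums d_y * (x - mean) / sqrt(var + eps); d_bias[j] sums d_y.
void ColumnGrads(const std::vector<float> &x, const std::vector<float> &d_y,
                 const std::vector<float> &mean, const std::vector<float> &var,
                 int batch_size, int feature_size, int j, float epsilon,
                 float *d_scale_j, float *d_bias_j) {
  float ds = 0.f, db = 0.f;
  for (int r = 0; r < batch_size; ++r) {
    int i = r * feature_size + j;
    float std_val = std::sqrt(var[r] + epsilon);
    ds += d_y[i] * (x[i] - mean[r]) / std_val;
    db += d_y[i];
  }
  *d_scale_j = ds;
  *d_bias_j = db;
}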
@@ -205,22 +203,90 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   }
 }
 
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
+                                                          const T *mean,
+                                                          const T *var,
+                                                          float epsilon,
+                                                          int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+
+  T block_mean = mean[blockIdx.x];
+  T block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
+  }
+}
+
 // Here, we only calculate d_x
-template <typename T>
-__global__ void LayerNormBackwardGradientOnlyX(const T *d_y, T *d_x,
-                                               const T *var, const T *scale,
-                                               float epsilon, int batch_size,
-                                               int feature_size) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < batch_size * feature_size) {
-    int row_idx = idx / feature_size;
-    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
+                                                T *d_x, const T *mean,
+                                                const T *var, const T *scale,
+                                                float epsilon,
+                                                int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+
+  T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
     if (scale != nullptr) {
-      int col_idx = idx % feature_size;
-      d_x[idx] = d_y[idx] * scale[col_idx] / var_val;
+      int col_idx = i % feature_size;
+      d_x[i] = d_y[i] * scale[col_idx] / var_val;
     } else {
-      d_x[idx] = d_y[idx] / var_val;
+      d_x[i] = d_y[i] / var_val;
     }
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
   }
 }
 
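
The new LayerNormBackwardPostProcessToCalculateDX kernel finishes d_x after an earlier kernel has filled it with the elementwise term: per row it subtracts the mean of d_x and the component along (x - mean), the latter scaled by 1 / (feature_size * (var + epsilon)). A host-side, single-row sketch of that correction (PostProcessRow is my own name, not in the patch):

#include <vector>

void PostProcessRow(const std::vector<float> &x, std::vector<float> &d_x,
                    float mean, float var, float epsilon) {
  const int n = static_cast<int>(x.size());
  float d_x_mean = 0.f, d_x_var = 0.f;
  for (int i = 0; i < n; ++i) {
    d_x_mean += d_x[i];
    d_x_var += d_x[i] * (x[i] - mean);
  }
  d_x_mean /= n;
  d_x_var /= n * (var + epsilon);
  // Same update the kernel applies element by element.
  for (int i = 0; i < n; ++i) {
    d_x[i] -= d_x_mean + (x[i] - mean) * d_x_var;
  }
}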
@@ -263,6 +329,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
         T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
              stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
                        feature_size);
+
+    if (d_x != nullptr) {
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
+                             T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
+            x, d_x, mean, var, epsilon, feature_size));
+      }
+    }
     return;
   }
 
@@ -296,10 +370,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       }
       break;
     case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
-      LayerNormBackwardGradientOnlyX<
-          T><<<(batch_size * feature_size + kMaxBlockDim - 1) / kMaxBlockDim,
-               kMaxBlockDim, 0, stream>>>(d_y, d_x, var, scale, epsilon,
-                                          batch_size, feature_size);
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardGradientOnlyDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
+      }
       break;
     case 5:  // d_x != nulptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
@@ -309,6 +385,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
                 feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
@@ -318,6 +400,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
                 feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
       switch (block_dim) {
@@ -327,6 +415,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                 batch_size, feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     default:
       break;
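
All of the new launch sites follow the same dispatch pattern: GetDesiredBlockDim maps the runtime feature_size to a block size, and FIXED_BLOCK_DIM_CASE expands into switch cases in which kBlockDim is a compile-time constant, which cub::BlockReduce<T, BlockDim> requires as a template argument. A sketch of that pattern written as a plain switch; RowSum and DispatchRowSum are illustrative names of mine, not Paddle's macros, and the chosen block sizes are only examples:

#include <cub/cub.cuh>

// Any kernel built on cub::BlockReduce needs its block size as a template
// argument; this toy kernel sums one row per block.
template <typename T, int BlockDim>
__global__ void RowSum(const T *x, T *out, int feature_size) {
  using BlockReduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  T partial = 0;
  for (int i = threadIdx.x; i < feature_size; i += BlockDim) {
    partial += x[blockIdx.x * feature_size + i];
  }
  T sum = BlockReduce(temp_storage).Reduce(partial, cub::Sum());
  if (threadIdx.x == 0) out[blockIdx.x] = sum;
}

// Hand-written equivalent of the switch/FIXED_BLOCK_DIM_CASE dispatch: the
// runtime choice selects a case whose block dimension is a constant.
template <typename T>
void DispatchRowSum(const T *x, T *out, int rows, int feature_size,
                    cudaStream_t stream, int desired_block_dim) {
  switch (desired_block_dim) {
    case 512:
      RowSum<T, 512><<<rows, 512, 0, stream>>>(x, out, feature_size);
      break;
    case 256:
      RowSum<T, 256><<<rows, 256, 0, stream>>>(x, out, feature_size);
      break;
    default:
      RowSum<T, 32><<<rows, 32, 0, stream>>>(x, out, feature_size);
      break;
  }
}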