
Commit c50c537

fix arithmetic error in backward kernel
2 parents: 0108836 + 2ab122a

File tree: 2 files changed (+153, -59 lines)
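In outline: the forward kernel in layer_norm_op.cu now computes each row's mean and variance in a single fused cub::BlockReduce pass over a (sum, sum of squares) pair, and the backward path gains a post-processing step that adds the mean- and variance-related terms the previous d_x computation left out, which appears to be the arithmetic error named in the commit title. The variance falls out of the fused pass through the standard identity (a sketch of what the new code computes; the identity itself is not stated in the patch):

\operatorname{Var}[x] \;=\; \mathbb{E}[x^2] - \big(\mathbb{E}[x]\big)^2

so accumulating \sum_i x_i and \sum_i x_i^2 in one loop is enough to recover both statistics for a row.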

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -273,9 +273,9 @@ op_library(squeeze_op DEPS reshape_op)
 op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
 
-
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
+    op_library(layer_norm_op DEPS cub)
 else()
     op_library(conv_op DEPS vol2col im2col)
 endif()
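All of the layer_norm kernels touched by this commit rely on cub::BlockReduce, so the cub dependency only matters for the CUDA build; moving op_library(layer_norm_op DEPS cub) inside the if (WITH_GPU) branch presumably keeps CPU-only builds from pulling in cub.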

paddle/fluid/operators/layer_norm_op.cu

Lines changed: 152 additions & 58 deletions
@@ -45,38 +45,55 @@ inline static int GetDesiredBlockDim(int block_dim) {
 static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
 static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
 
+template <typename T>
+struct PairForLayerNorm {
+  __device__ __forceinline__ PairForLayerNorm() {}
+  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
+      : first_(first), second_(second) {}
+
+  T first_;
+  T second_;
+};
+
+template <typename T>
+struct PairForLayerNormAddFunctor {
+  __device__ __forceinline__ PairForLayerNorm<T> operator()(
+      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
+    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
+  }
+};
+
 template <typename T, int BlockDim>
 __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                  T *y, T *mean, T *var, float epsilon,
                                  int feature_size) {
-  using BlockReduce = cub::BlockReduce<T, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * feature_size;
 
-  // Step 1: Reduce to calculate mean
+  // Step 1: Reduce to calculate mean and var
   T mean_val = static_cast<T>(0);
-  for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    mean_val += x[i];
-  }
-  mean_val = BlockReduce(temp_storage).Reduce(mean_val, cub::Sum());
-  if (threadIdx.x == 0) mean[blockIdx.x] = mean_val / feature_size;
-  __syncthreads();
-  mean_val = mean[blockIdx.x];
-
-  // Step 2: Reduce to calculate var
   T var_val = static_cast<T>(0);
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T tmp = x[i] - mean_val;
+    T tmp = x[i];
+    mean_val += tmp;
     var_val += (tmp * tmp);
   }
-  var_val = BlockReduce(temp_storage).Reduce(var_val, cub::Sum());
-  if (threadIdx.x == 0) var[blockIdx.x] = var_val / feature_size;
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(mean_val, var_val),
+                          PairForLayerNormAddFunctor<T>());
+  if (threadIdx.x == 0) {
+    auto tmp = pair.first_ / feature_size;
+    mean[blockIdx.x] = tmp;
+    var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
+  }
   __syncthreads();
+  mean_val = mean[blockIdx.x];
   var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
 
-  // Step 3: Calculate y
+  // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
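The hunk above folds what used to be two reductions (one for the mean, one for the variance) into a single pair reduction per row. A host-side reference of the same computation is useful for checking the kernel's mean/var outputs; the sketch below is not part of the patch, and reference_mean_var is a hypothetical helper name.

#include <utility>
#include <vector>

// Hypothetical host-side reference for one row of the fused reduction:
// a single pass accumulates (sum, sum of squares), then mean and variance
// are derived the same way the kernel does after its BlockReduce.
template <typename T>
std::pair<T, T> reference_mean_var(const std::vector<T> &row) {
  T sum = 0;     // plays the role of pair.first_
  T sq_sum = 0;  // plays the role of pair.second_
  for (const T &v : row) {
    sum += v;
    sq_sum += v * v;
  }
  T mean = sum / static_cast<T>(row.size());
  T var = sq_sum / static_cast<T>(row.size()) - mean * mean;  // E[x^2] - E[x]^2
  return {mean, var};
}

For example, reference_mean_var<float>({1, 2, 3, 4}) gives mean 2.5 and variance 1.25, which is what the kernel's formula produces for that row.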
@@ -104,27 +121,6 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
   }
 }
 
-template <typename T>
-struct PairForLayerNormBackward {
-  __device__ __forceinline__ PairForLayerNormBackward() {}
-  __device__ __forceinline__ PairForLayerNormBackward(const T &first,
-                                                      const T &second)
-      : first_(first), second_(second) {}
-
-  T first_;
-  T second_;
-};
-
-template <typename T>
-struct PairForLayerNormBackwardAddFunctor {
-  __device__ __forceinline__ PairForLayerNormBackward<T> operator()(
-      const PairForLayerNormBackward<T> &p1,
-      const PairForLayerNormBackward<T> &p2) {
-    return PairForLayerNormBackward<T>(p1.first_ + p2.first_,
-                                       p1.second_ + p2.second_);
-  }
-};
-
 // Make sure that d_scale != nullptr && d_bias != nullptr
 // Since d_scale != nullptr, scale would not be nullptr
 template <typename T, int BlockDim, bool HasDx>
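With PairForLayerNorm and PairForLayerNormAddFunctor hoisted above the forward kernel, the backward-only copies deleted here are redundant; the next hunk switches LayerNormBackwardGradientAll over to the shared types.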
@@ -133,26 +129,28 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                              const T *mean, const T *var,
                                              const T *scale, float epsilon,
                                              int batch_size, int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNormBackward<T>, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int beg_idx = threadIdx.x * feature_size + blockIdx.x;
   int end_idx = batch_size * feature_size + blockIdx.x;
   int stride = BlockDim * feature_size;
+
   T d_scale_partial = 0, d_bias_partial = 0;
 
   for (int i = beg_idx; i < end_idx; i += stride) {
     int row_idx = i / feature_size;
     auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
     d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
     d_bias_partial += d_y[i];
-    if (HasDx) d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    if (HasDx) {
+      d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+    }
   }
 
-  auto pair =
-      BlockReduce(temp_storage)
-          .Reduce(PairForLayerNormBackward<T>(d_scale_partial, d_bias_partial),
-                  PairForLayerNormBackwardAddFunctor<T>());
+  auto pair = BlockReduce(temp_storage)
+                  .Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
+                          PairForLayerNormAddFunctor<T>());
 
   if (threadIdx.x == 0) {
     d_scale[blockIdx.x] = pair.first_;
@@ -205,22 +203,90 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
   }
 }
 
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
+                                                          const T *mean,
+                                                          const T *var,
+                                                          float epsilon,
+                                                          int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+
+  T block_mean = mean[blockIdx.x];
+  T block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
+  }
+}
+
 // Here, we only calculate d_x
-template <typename T>
-__global__ void LayerNormBackwardGradientOnlyX(const T *d_y, T *d_x,
-                                               const T *var, const T *scale,
-                                               float epsilon, int batch_size,
-                                               int feature_size) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < batch_size * feature_size) {
-    int row_idx = idx / feature_size;
-    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
+template <typename T, int BlockDim>
+__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
+                                                T *d_x, const T *mean,
+                                                const T *var, const T *scale,
+                                                float epsilon,
+                                                int feature_size) {
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ T d_x_reduce_tmp[2];
+
+  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * feature_size;
+
+  T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
+  T d_x_mean_partial = 0, d_x_var_partial = 0;
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
     if (scale != nullptr) {
-      int col_idx = idx % feature_size;
-      d_x[idx] = d_y[idx] * scale[col_idx] / var_val;
+      int col_idx = i % feature_size;
+      d_x[i] = d_y[i] * scale[col_idx] / var_val;
     } else {
-      d_x[idx] = d_y[idx] / var_val;
+      d_x[i] = d_y[i] / var_val;
     }
+    d_x_mean_partial += d_x[i];
+    d_x_var_partial += d_x[i] * (x[i] - block_mean);
+  }
+
+  auto pair =
+      BlockReduce(temp_storage)
+          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
+                  PairForLayerNormAddFunctor<T>());
+
+  if (threadIdx.x == 0) {
+    d_x_reduce_tmp[0] = pair.first_ / feature_size;
+    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
+  }
+  __syncthreads();
+
+  d_x_mean_partial = d_x_reduce_tmp[0];
+  d_x_var_partial = d_x_reduce_tmp[1];
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    d_x[i] -= d_x_mean_partial;
+    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
   }
 }
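The two kernels added in this hunk give d_x the full layer-norm input gradient instead of stopping at d_y * scale / sqrt(var + epsilon). As a sketch of the math the code implements (the derivation is not spelled out in the patch itself): with m = feature_size, per-row mean \mu and variance \sigma^2, scale \gamma, and g_i = \frac{\partial L}{\partial y_i}\,\frac{\gamma_i}{\sqrt{\sigma^2 + \epsilon}} for the value first written into d_x, the standard layer-norm gradient is

\frac{\partial L}{\partial x_i}
  \;=\; g_i \;-\; \frac{1}{m}\sum_{j=1}^{m} g_j
  \;-\; \frac{x_i - \mu}{m\,(\sigma^2 + \epsilon)} \sum_{j=1}^{m} g_j\,(x_j - \mu).

The post-processing loop subtracts exactly these two terms: d_x_reduce_tmp[0] holds the row mean of d_x, and d_x_reduce_tmp[1] holds sum(d_x[i] * (x[i] - mean)) / (feature_size * (var + epsilon)). The removed LayerNormBackwardGradientOnlyX produced only g_i, which is presumably the arithmetic error the commit title refers to.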

@@ -263,6 +329,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
           T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
               stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
                         feature_size);
+
+    if (d_x != nullptr) {
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
+                             T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
+            x, d_x, mean, var, epsilon, feature_size));
+      }
+    }
     return;
   }
 
@@ -296,10 +370,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       }
       break;
     case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
-      LayerNormBackwardGradientOnlyX<
-          T><<<(batch_size * feature_size + kMaxBlockDim - 1) / kMaxBlockDim,
-               kMaxBlockDim, 0, stream>>>(d_y, d_x, var, scale, epsilon,
-                                          batch_size, feature_size);
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardGradientOnlyDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
+      }
       break;
     case 5:  // d_x != nulptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
@@ -309,6 +385,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
             x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
             feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
@@ -318,6 +400,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
             x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
             feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
       switch (block_dim) {
@@ -327,6 +415,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
               x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
               batch_size, feature_size));
       }
+      switch (GetDesiredBlockDim(feature_size)) {
+        FIXED_BLOCK_DIM_CASE(
+            LayerNormBackwardPostProcessToCalculateDX<
+                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
+                x, d_x, mean, var, epsilon, feature_size));
+      }
       break;
     default:
       break;
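Each new launch site follows the same dispatch pattern: GetDesiredBlockDim (declared earlier in layer_norm_op.cu, per the first hunk header) appears to map the runtime feature_size to a supported block size, and FIXED_BLOCK_DIM_CASE instantiates the kernel with that size as the compile-time kBlockDim template argument; neither definition is shown in this diff. The sketch below uses made-up names (PickBlockDim, DISPATCH_BLOCK_DIM, RunWithBlockDim) to illustrate the general technique on the host side, not Paddle's exact macros.

#include <cstdio>

// Stand-in for launching a kernel templated on a compile-time block size,
// e.g. SomeKernel<T, kBlockDim><<<grid, kBlockDim>>>(...). Here it only prints.
template <int kBlockDim>
void RunWithBlockDim(int feature_size) {
  std::printf("feature_size=%d dispatched to BlockDim=%d\n", feature_size,
              kBlockDim);
}

// Round the runtime feature size down to a supported power-of-two block size.
inline int PickBlockDim(int feature_size) {
  for (int dim = 512; dim > 1; dim /= 2) {
    if (feature_size >= dim) return dim;
  }
  return 1;
}

// Each case turns the runtime value back into a compile-time constant.
#define DISPATCH_BLOCK_DIM(feature_size)                   \
  switch (PickBlockDim(feature_size)) {                    \
    case 512: RunWithBlockDim<512>(feature_size); break;   \
    case 256: RunWithBlockDim<256>(feature_size); break;   \
    case 128: RunWithBlockDim<128>(feature_size); break;   \
    case 64:  RunWithBlockDim<64>(feature_size); break;    \
    default:  RunWithBlockDim<32>(feature_size); break;    \
  }

int main() {
  DISPATCH_BLOCK_DIM(200);  // prints "feature_size=200 dispatched to BlockDim=128"
  return 0;
}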
