Commit 5faaf3e

[BIG tensor] fix cuda error of layer_norm (#74404)
1 parent 0c33056 commit 5faaf3e

File tree

2 files changed: +14, -12 lines


paddle/phi/kernels/funcs/layer_norm_impl.cu.h

Lines changed: 3 additions & 1 deletion
@@ -550,7 +550,9 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
   for (int k = 0; k < VPT; ++k) {
     const int i2 = i2_off + k;
     const int64_t load_idx = i1 * n2 + i2;
-    const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    const int64_t write_idx =
+        static_cast<int64_t>(thr_load_row_off) * row_stride + thr_load_col_off +
+        k;
     if (i2 < n2) {
       U curr_input = static_cast<U>(input[load_idx]);
       U curr_dout = static_cast<U>(dout[load_idx]);

paddle/phi/kernels/gpu/layer_norm_kernel.cu

Lines changed: 11 additions & 11 deletions
@@ -159,7 +159,7 @@ struct LayerNormDataWriter {
           temp_dst[j] = static_cast<T>((buffer[i * VecSize + j] - row_mean) *
                                        row_inv_var);
         }
-        v_dst[threadIdx.x + blockDim.x * i] = temp_dst;
+        v_dst[threadIdx.x + static_cast<int64_t>(blockDim.x) * i] = temp_dst;
       }
     } else {
       const VecScaleT *__restrict__ v_scale =
@@ -168,7 +168,7 @@
           reinterpret_cast<const VecScaleT *__restrict__>(bias);
       if (valid_scale && valid_bias) {
         for (int i = 0; i < write_times; ++i) {
-          int idx = threadIdx.x + blockDim.x * i;
+          int64_t idx = threadIdx.x + static_cast<int64_t>(blockDim.x) * i;
           VecT temp_dst;
           VecScaleT temp_v_scale = v_scale[idx];
           VecScaleT temp_v_bias = v_bias[idx];
@@ -184,7 +184,7 @@
       } else {
         if (valid_scale) {
           for (int i = 0; i < write_times; ++i) {
-            int idx = threadIdx.x + blockDim.x * i;
+            int64_t idx = threadIdx.x + static_cast<int64_t>(blockDim.x) * i;
             VecT temp_dst;
             VecScaleT temp_v_scale = v_scale[idx];
 #pragma unroll
@@ -232,27 +232,27 @@ struct LayerNormDataWriter<T, U, IsSameType, 1> {
     if ((!valid_scale) && (!valid_bias)) {
       if (threadIdx.x < last_tid_idx) {
         for (int i = 0; i < cols_this_thread; ++i) {
-          row_dst[threadIdx.x + last_tid_idx * i] =
+          row_dst[threadIdx.x + static_cast<int64_t>(last_tid_idx) * i] =
               (buffer[i] - row_mean) * row_inv_var;
         }
       } else {
         for (int i = 0; i < cols_this_thread; ++i) {
-          row_dst[last_tid_idx * write_times + i] =
+          row_dst[static_cast<int64_t>(last_tid_idx) * write_times + i] =
               (buffer[i] - row_mean) * row_inv_var;
         }
       }
     } else if (valid_scale && valid_bias) {
       if (threadIdx.x < last_tid_idx) {
         for (int i = 0; i < cols_this_thread; ++i) {
-          int idx = threadIdx.x + last_tid_idx * i;
+          int64_t idx = threadIdx.x + static_cast<int64_t>(last_tid_idx) * i;
           row_dst[idx] =
               static_cast<T>(static_cast<U>(scale[idx]) *
                                  (buffer[i] - row_mean) * row_inv_var +
                              static_cast<U>(bias[idx]));
         }
       } else {
         for (int i = 0; i < cols_this_thread; ++i) {
-          int idx = last_tid_idx * write_times + i;
+          int64_t idx = static_cast<int64_t>(last_tid_idx) * write_times + i;
           row_dst[idx] =
               static_cast<T>(static_cast<U>(scale[idx]) *
                                  (buffer[i] - row_mean) * row_inv_var +
@@ -263,27 +263,27 @@ struct LayerNormDataWriter<T, U, IsSameType, 1> {
       if (valid_scale) {
         if (threadIdx.x < last_tid_idx) {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = threadIdx.x + last_tid_idx * i;
+            int64_t idx = threadIdx.x + static_cast<int64_t>(last_tid_idx) * i;
             row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
                                           (buffer[i] - row_mean) * row_inv_var);
           }
         } else {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = last_tid_idx * write_times + i;
+            int64_t idx = static_cast<int64_t>(last_tid_idx) * write_times + i;
             row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
                                           (buffer[i] - row_mean) * row_inv_var);
           }
         }
       } else {
         if (threadIdx.x < last_tid_idx) {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = threadIdx.x + last_tid_idx * i;
+            int64_t idx = threadIdx.x + static_cast<int64_t>(last_tid_idx) * i;
             row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
                                           static_cast<U>(bias[idx]));
           }
         } else {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = last_tid_idx * write_times + i;
+            int64_t idx = static_cast<int64_t>(last_tid_idx) * write_times + i;
             row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
                                           static_cast<U>(bias[idx]));
           }
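The pattern throughout this commit is the same: cast one operand of the flattened-index arithmetic to int64_t so the whole expression is evaluated in 64-bit. Below is a minimal, self-contained host-side sketch of the failure mode; it is not Paddle code, the names only mirror the kernel, and the sizes are hypothetical, chosen just so the product exceeds INT_MAX (the "BIG tensor" case).

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes: chosen only so that row offset * stride > INT_MAX.
  const int thr_load_row_off = 70000;
  const int row_stride = 40000;  // 70000 * 40000 = 2.8e9 > INT_MAX

  // What the patch does: casting the first operand promotes the whole
  // expression to 64-bit arithmetic, so the flattened index stays correct.
  const int64_t idx64 = static_cast<int64_t>(thr_load_row_off) * row_stride;

  // What the old all-int expression effectively produced: the product wraps
  // modulo 2^32 and comes back as a negative 32-bit value, the kind of
  // out-of-bounds index that surfaces as a CUDA error on large inputs.
  // (Signed int overflow is UB, so the wrap is modeled explicitly here.)
  const int32_t idx32 = static_cast<int32_t>(static_cast<uint32_t>(idx64));

  std::printf("64-bit index: %lld\n", static_cast<long long>(idx64));
  std::printf("32-bit index: %d\n", idx32);
  return 0;
}
```

Casting only the first operand is enough: C++'s usual arithmetic conversions then promote the remaining int operands (row_stride, thr_load_col_off, k, the loop counter i) to int64_t before the multiply and adds, which is why the diff does not need to widen those variables themselves.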
