@@ -159,7 +159,7 @@ struct LayerNormDataWriter {
           temp_dst[j] = static_cast<T>((buffer[i * VecSize + j] - row_mean) *
                                        row_inv_var);
         }
-        v_dst[threadIdx.x + blockDim.x * i] = temp_dst;
+        v_dst[threadIdx.x + static_cast<int64_t>(blockDim.x) * i] = temp_dst;
       }
     } else {
       const VecScaleT *__restrict__ v_scale =
@@ -168,7 +168,7 @@ struct LayerNormDataWriter {
           reinterpret_cast<const VecScaleT *__restrict__>(bias);
       if (valid_scale && valid_bias) {
         for (int i = 0; i < write_times; ++i) {
-          int idx = threadIdx.x + blockDim.x * i;
+          int64_t idx = threadIdx.x + static_cast<int64_t>(blockDim.x) * i;
           VecT temp_dst;
           VecScaleT temp_v_scale = v_scale[idx];
           VecScaleT temp_v_bias = v_bias[idx];
@@ -184,7 +184,7 @@ struct LayerNormDataWriter {
       } else {
         if (valid_scale) {
           for (int i = 0; i < write_times; ++i) {
-            int idx = threadIdx.x + blockDim.x * i;
+            int64_t idx = threadIdx.x + static_cast<int64_t>(blockDim.x) * i;
             VecT temp_dst;
             VecScaleT temp_v_scale = v_scale[idx];
 #pragma unroll
@@ -232,27 +232,27 @@ struct LayerNormDataWriter<T, U, IsSameType, 1> {
     if ((!valid_scale) && (!valid_bias)) {
       if (threadIdx.x < last_tid_idx) {
         for (int i = 0; i < cols_this_thread; ++i) {
-          row_dst[threadIdx.x + last_tid_idx * i] =
+          row_dst[threadIdx.x + static_cast<int64_t>(last_tid_idx) * i] =
               (buffer[i] - row_mean) * row_inv_var;
         }
       } else {
         for (int i = 0; i < cols_this_thread; ++i) {
-          row_dst[last_tid_idx * write_times + i] =
+          row_dst[static_cast<int64_t>(last_tid_idx) * write_times + i] =
               (buffer[i] - row_mean) * row_inv_var;
         }
       }
     } else if (valid_scale && valid_bias) {
       if (threadIdx.x < last_tid_idx) {
         for (int i = 0; i < cols_this_thread; ++i) {
-          int idx = threadIdx.x + last_tid_idx * i;
+          int64_t idx = threadIdx.x + static_cast<int64_t>(last_tid_idx) * i;
           row_dst[idx] =
               static_cast<T>(static_cast<U>(scale[idx]) *
                                  (buffer[i] - row_mean) * row_inv_var +
                              static_cast<U>(bias[idx]));
         }
       } else {
         for (int i = 0; i < cols_this_thread; ++i) {
-          int idx = last_tid_idx * write_times + i;
+          int64_t idx = static_cast<int64_t>(last_tid_idx) * write_times + i;
           row_dst[idx] =
               static_cast<T>(static_cast<U>(scale[idx]) *
                                  (buffer[i] - row_mean) * row_inv_var +
@@ -263,27 +263,27 @@ struct LayerNormDataWriter<T, U, IsSameType, 1> {
       if (valid_scale) {
         if (threadIdx.x < last_tid_idx) {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = threadIdx.x + last_tid_idx * i;
+            int64_t idx = threadIdx.x + static_cast<int64_t>(last_tid_idx) * i;
             row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
                                           (buffer[i] - row_mean) * row_inv_var);
           }
         } else {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = last_tid_idx * write_times + i;
+            int64_t idx = static_cast<int64_t>(last_tid_idx) * write_times + i;
             row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
                                           (buffer[i] - row_mean) * row_inv_var);
           }
         }
       } else {
         if (threadIdx.x < last_tid_idx) {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = threadIdx.x + last_tid_idx * i;
+            int64_t idx = threadIdx.x + static_cast<int64_t>(last_tid_idx) * i;
             row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
                                           static_cast<U>(bias[idx]));
           }
         } else {
           for (int i = 0; i < cols_this_thread; ++i) {
-            int idx = last_tid_idx * write_times + i;
+            int64_t idx = static_cast<int64_t>(last_tid_idx) * write_times + i;
             row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
                                           static_cast<U>(bias[idx]));
           }
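
Every hunk above applies the same fix: the destination index was previously computed in 32-bit arithmetic (threadIdx.x + blockDim.x * i, last_tid_idx * write_times + i, and so on), which wraps once the product no longer fits in 32 bits, so one operand is promoted to int64_t and the whole expression is evaluated in 64 bits. Below is a minimal host-side sketch of the failure mode, not taken from the PR; block_dim_x and i are made-up stand-ins for blockDim.x and the write-loop counter, chosen large enough that the 32-bit product exceeds INT32_MAX.

    #include <cstdint>
    #include <cstdio>

    int main() {
      unsigned int block_dim_x = 1024;  // stand-in for blockDim.x (unsigned int on the device)
      int i = 3000000;                  // stand-in for the write-loop counter (hypothetical value)

      // Unwidened path, mirroring the old "int idx = ...": the unsigned product
      // 3,072,000,000 does not fit in int32_t, so storing it in a 32-bit index
      // yields a negative (wrapped) value and an out-of-bounds access.
      int32_t narrow = static_cast<int32_t>(block_dim_x * static_cast<unsigned int>(i));

      // Widened path, mirroring the patch: promote one operand to int64_t before
      // multiplying so the whole expression is evaluated in 64 bits.
      int64_t wide = static_cast<int64_t>(block_dim_x) * i;

      std::printf("narrow = %d, wide = %lld\n", narrow, static_cast<long long>(wide));
      return 0;
    }

Casting only the final result (static_cast<int64_t>(block_dim_x * i)) would not help, since the multiplication would already have wrapped in 32 bits; the cast has to land on an operand, which is exactly what the changed lines do.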