
Commit 01b7256

[large tensor] Use int64_t for CUDA indexing to avoid overflow (PaddlePaddle#76303)
1 parent 5387ef2 commit 01b7256

68 files changed: +451 -152 lines

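Every hunk in this commit applies the same fix: the global CUDA thread index was previously computed in 32-bit arithmetic (e.g. blockIdx.x * blockDim.x + threadIdx.x), which overflows once a kernel launch has to cover more than 2^31 - 1 elements. Widening each operand to int64_t before the multiply-add keeps the index in range for large tensors. A minimal sketch of the pattern, using a hypothetical element-wise kernel that is not part of this commit:

#include <cstdint>

// Hypothetical element-wise kernel; the name and signature are illustrative only.
// With plain int, blockIdx.x * blockDim.x + threadIdx.x wraps around once the
// launch covers more than INT_MAX threads; widening each operand first keeps
// the whole computation in 64-bit.
template <typename T>
__global__ void ScaleKernel(const T* in, T* out, int64_t n, T alpha) {
  int64_t idx =
      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
      static_cast<int64_t>(threadIdx.x);
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = idx; i < n; i += stride) {  // grid-stride loop over n elements
    out[i] = in[i] * alpha;
  }
}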

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 3 additions & 1 deletion
@@ -212,7 +212,9 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
     using VecType = phi::kps::details::VectorType<Type, VecSize>;
     VecType vec_temp;
 
-    int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
+    int64_t thread_offset =
+        static_cast<int64_t>(threadIdx.x) +
+        static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
     const VecType *__restrict__ vec_input =
         reinterpret_cast<const VecType *__restrict__>(ins[Index]);
     vec_temp = vec_input[thread_offset];

paddle/phi/kernels/funcs/detail/gru_gpu_kernel.h

Lines changed: 6 additions & 2 deletions
@@ -128,7 +128,9 @@ __global__ void KeFastCollectiveGruGate(T *gate_value,
   T c0 = 0.0f;
   T b0[Tiled_size];
 
-  int COL = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t COL =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int Tiled_mask = ((1 << Tiled_size) - 1);
   // Tiled matrix multiply using register shift, faster than sm.
   if (prev_output_value) {
@@ -185,7 +187,9 @@ __global__ void KeFastCollectiveGruOut(const T *gate_weight,
                                        int frame_size,
                                        ActivationType act_node,
                                        bool origin_mode) {
-  int COL = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t COL =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
 
   T a0 = 0.0f;
   T b0[Tiled_size];

paddle/phi/kernels/funcs/fake_quantize_functor.cu

Lines changed: 9 additions & 3 deletions
@@ -29,7 +29,9 @@ struct QuantizeDataType<phi::float16> {
 
 template <typename T>
 __global__ void FindAbsMaxKernel(const T *in, const int64_t n, T *out) {
-  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t bid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   int tid = threadIdx.x;
 
   extern __shared__ char *shared_max_data_tmp[];
@@ -70,7 +72,9 @@ __global__ void ClipAndQuantKernel(const T *in,
                                    const int round_type,
                                    const int64_t n,
                                    T *out) {
-  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t bid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   int tid = threadIdx.x;
 
   using ComputeDataType = typename QuantizeDataType<T>::type;
@@ -155,7 +159,9 @@ __global__ void ClipAndQuantDequantKernel(const T *in,
                                           const int round_type,
                                           const int64_t n,
                                           T *out) {
-  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t bid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   int tid = threadIdx.x;
 
   using ComputeDataType = typename QuantizeDataType<T>::type;

paddle/phi/kernels/funcs/fc_functor.cu

Lines changed: 3 additions & 1 deletion
@@ -63,7 +63,9 @@ struct FcTypeTraits<float16> {
 
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t tid =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   if (tid < num) {
     int bias_idx = tid % K;
     const T bias_ptr = bias[bias_idx];

paddle/phi/kernels/funcs/math_function.cu

Lines changed: 4 additions & 1 deletion
@@ -209,7 +209,10 @@ DEFINE_GPU_TRANS(6);
 
 template <typename T>
 __global__ void FillConstantKernel(const int N, T* a, const T val) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+  for (int64_t i =
+           static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+           static_cast<int64_t>(threadIdx.x);
+       i < N;
        i += blockDim.x * gridDim.x) {
     a[i] = val;
   }

paddle/phi/kernels/funcs/norm_utils.cu.h

Lines changed: 6 additions & 2 deletions
@@ -370,7 +370,9 @@ __global__ void DoubleGradComputeDXWithGlobal(const T *dy,
                                               const int sample_size,
                                               const int64_t num,
                                               T *dx) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t gid =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int stride = blockDim.x * gridDim.x;
   if (ddscale != nullptr) {
     for (int64_t i = gid; i < num; i += stride) {
@@ -397,7 +399,9 @@ __global__ void DoubleGradComputeDDYWithGlobal(const T *ddx,
                                                const int sample_size,
                                                const int64_t num,
                                                T *ddy) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t gid =
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+      static_cast<int64_t>(threadIdx.x);
   int stride = blockDim.x * gridDim.x;
 
   if (ddx != nullptr) {

paddle/phi/kernels/funcs/quant_dequant.h

Lines changed: 42 additions & 12 deletions
@@ -91,8 +91,13 @@ __global__ void QuantKernel(const T* input,
                             const int round_type,
                             const float max_bound,
                             const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -118,8 +123,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -145,8 +155,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      3;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -170,8 +185,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -193,8 +213,12 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                        const int round_type,
                                        const float max_bound,
                                        const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x);
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x));
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -320,7 +344,10 @@ __global__ void DequantKernel(T* output,
                               const float* dequant_out_scale_data) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
@@ -366,7 +393,10 @@ __global__ void DequantKernelWithScaleOfInputAndWeight(
     float quant_max_bound) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
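One detail worth noting in the QuantKernel / DequantKernel hunks above: the thread index is scaled by the vector width (<< 2, * 3, * 2, or * VecSize), so the widening has to happen before that scaling; a thread index that still fits in 32 bits can exceed INT_MAX once multiplied by the vector width. A minimal sketch of the vectorized-indexing pattern, using a hypothetical kernel that is not part of this commit:

#include <cstdint>

// Illustrative vectorized copy: each thread handles VecSize consecutive
// elements, so the element index is the widened thread index times VecSize.
// Widening only after the multiplication would already have overflowed.
template <typename T, int VecSize>
__global__ void CopyVectorized(const T* in, T* out, int64_t numel) {
  int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
       static_cast<int64_t>(threadIdx.x)) *
      VecSize;
  if (idx + VecSize <= numel) {
#pragma unroll
    for (int v = 0; v < VecSize; ++v) {
      out[idx + v] = in[idx + v];
    }
  }
}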

paddle/phi/kernels/funcs/scatter.cu.h

Lines changed: 2 additions & 1 deletion
@@ -402,7 +402,8 @@ inline DenseTensor restride_dim(const phi::DenseTensor& src,
 template <int nt, int vt, typename func_t>
 __global__ void scatter_gather_elementwise_kernel(int N, func_t f) {
   constexpr int nv = nt * vt;
-  int idx = nv * blockIdx.x + threadIdx.x;
+  int64_t idx =
+      nv * static_cast<int64_t>(blockIdx.x) + static_cast<int64_t>(threadIdx.x);
 
 #pragma unroll
   for (int i = 0; i < vt; ++i) {

paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h

Lines changed: 6 additions & 2 deletions
@@ -26,7 +26,9 @@ __global__ void FlattenIndicesKernel(const IntT* indices,
                                      const int64_t non_zero_num,
                                      const int64_t sparse_dim,
                                      IntT* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   phi::funcs::sparse::FlattenIndices<IntT>(indices,
                                            sparse_offsets,
                                            non_zero_num,
@@ -42,7 +44,9 @@ __global__ void IndexToCoordinateKernel(const IntT* index,
                                         const int64_t non_zero_num,
                                         const int64_t sparse_dim,
                                         IntT* indices) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   IndexToCoordinate(index,
                     dims,
                     non_zero_num,

paddle/phi/kernels/funcs/sparse/scatter.cu.h

Lines changed: 6 additions & 2 deletions
@@ -41,7 +41,9 @@ __global__ void ScatterKernel(const T* input,
                               const int rulebook_len,
                               const int channels,
                               T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int vec_channels = channels / VecSize;
   using LoadT = phi::AlignedVector<T, VecSize>;
   using StoreT = phi::AlignedVector<T, VecSize>;
@@ -82,7 +84,9 @@ __global__ void ScatterKernelV2(const T* input,
                                 const int channels,
                                 const int buffer_counts,
                                 T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t tid =
+      static_cast<int64_t>(threadIdx.x) +
+      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int vec_channels = channels / VecSize;
   using LoadT = phi::AlignedVector<T, VecSize>;
   using StoreT = phi::AlignedVector<T, VecSize>;
