@@ -91,8 +91,13 @@ __global__ void QuantKernel(const T* input,
                              const int round_type,
                              const float max_bound,
                              const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -118,8 +123,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                         const int round_type,
                                         const float max_bound,
                                         const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x))
+      << 2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -145,8 +155,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                         const int round_type,
                                         const float max_bound,
                                         const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 3;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      3;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -170,8 +185,13 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                         const int round_type,
                                         const float max_bound,
                                         const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      2;
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -193,8 +213,12 @@ __global__ void QuantKernelWithVecSize(const T* input,
                                         const int round_type,
                                         const float max_bound,
                                         const float min_bound) {
-  int n_id = (blockIdx.x * blockDim.x + threadIdx.x);
-  int m_id = blockIdx.y * blockDim.y + threadIdx.y;
+  int64_t n_id =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x));
+  int64_t m_id =
+      static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y) +
+      static_cast<int64_t>(threadIdx.y);
 
   bool check = ((m_id < m) && (n_id < n));
   if (check) {
@@ -320,7 +344,10 @@ __global__ void DequantKernel(T* output,
                                const float* dequant_out_scale_data) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
@@ -366,7 +393,10 @@ __global__ void DequantKernelWithScaleOfInputAndWeight(
     float quant_max_bound) {
   int numel = m * n;
   int stride = blockDim.x * gridDim.x * VecSize;
-  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
+  int64_t idx =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       static_cast<int64_t>(threadIdx.x)) *
+      VecSize;
   int col_id = idx % n;
 
   phi::AlignedVector<int32_t, VecSize> in_vec;
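Every hunk above makes the same change: the flat thread index is computed in int64_t rather than int, since for tensors with more than 2^31 elements the 32-bit product blockIdx.x * blockDim.x, and the row-major offset derived from it, can wrap around and index out of bounds. A minimal standalone sketch of the pattern (the kernel name below is hypothetical and is not part of this commit) could look like this:

#include <cstdint>

// Illustrative sketch only, assuming a 2-D row-major tensor of shape m x n.
// Casting the first operand to int64_t promotes the whole expression to
// 64-bit arithmetic, so m_id * n + n_id cannot overflow even when
// m * n > INT_MAX.
__global__ void FlatIndexSketch(const float* in, float* out, int m, int n) {
  int64_t n_id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;  // column
  int64_t m_id =
      static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;  // row
  if (m_id < m && n_id < n) {
    int64_t idx = m_id * n + n_id;  // 64-bit flat offset into the tensor
    out[idx] = in[idx];
  }
}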