@@ -61,16 +61,17 @@ template <typename T, int DATA_TYPE>
6161void dequantizeBlockwise (
6262 float * code, unsigned char * A, float * absmax, T* out, int blocksize, const int n, cudaStream_t stream
6363) {
64- // printf("stream==%d\n",stream);
65- int num_blocks = n / blocksize;
66- num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1 ;
67- int tile_size = (DATA_TYPE > 0 ) ? 1024 : 512 ;
64+ constexpr int tile_size = (DATA_TYPE > 0 ) ? 1024 : 512 ;
65+
66+ // Upcast to int64 to avoid overflow for large n
67+ int grid_blocks = ((int64_t )n + tile_size - 1 ) / tile_size;
68+
6869 if (DATA_TYPE > 0 )
6970 kDequantizeBlockwise <T, 512 , 64 , 8 , DATA_TYPE>
70- <<<(n + tile_size - 1 ) / tile_size , 64 , 0 , stream>>> (code, A, absmax, out, blocksize / 2 , n);
71+ <<<grid_blocks , 64 , 0 , stream>>> (code, A, absmax, out, blocksize / 2 , n);
7172 else
7273 kDequantizeBlockwise <T, 512 , 64 , 8 , DATA_TYPE>
73- <<<(n + tile_size - 1 ) / tile_size , 64 , 0 , stream>>> (code, A, absmax, out, blocksize, n);
74+ <<<grid_blocks , 64 , 0 , stream>>> (code, A, absmax, out, blocksize, n);
7475
7576 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
7677}
0 commit comments