55
#include "kernels.cuh"
#include "common.cuh"

#include <cfloat>
#include <cstdint>
#include <cuda_fp16.h>

#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
@@ -2141,7 +2142,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
// Row-wise symmetric int8 quantization of a row-major matrix A [rows x cols].
//
// Launch configuration: one block per row (gridDim.x == rows), blockDim.x == THREADS.
// Outputs:
//   out[r*cols + c] = round(A[r][c] * 127 / absmax(row r))
//   rowStats[r]     = absmax of row r (as float)
// SPARSE_DECOMP != 0 enables sparse decomposition: values with |v| >= threshold
// are treated as outliers — excluded from the row absmax and written out as 0.
template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {

    // Block-wide max-reduction over per-thread partial absmax values.
    using BlockReduceT = cub::BlockReduce<T, THREADS>;

    // One block per row.
    // Threads load column values in a striped arrangement.
    // e.g. t0 reads cols [0, THREADS, 2*THREADS, ...], t1 reads [1, 1+THREADS, ...]
    // We then do a blockwise reduction to determine the row's absmax.

    __shared__ typename BlockReduceT::TempStorage temp_storage;
    __shared__ T smem_row_absmax; // row absmax broadcast to all threads for the quantization pass

    const int row_id = blockIdx.x;
    const T* row_data = A + (row_id * cols);

    // Thread-local absmax over this thread's striped slice of the row.
    // NOTE: seeded with -FLT_MIN rather than 0 (kept from the original kernel);
    // any real |value| >= 0 dominates it in fmaxf, so the result is unaffected
    // whenever at least one in-threshold value exists.
    T row_local_absmax = -FLT_MIN;
    for (int i = threadIdx.x; i < cols; i += THREADS) {
        // __ldcs: streaming load — row values are read once here and once below,
        // so avoid polluting the cache.
        const T absval = fabsf(__ldcs(&(row_data[i])));

        // For sparse decomposition, values outside of the threshold are not to be
        // included when calculating the row's absmax.
        if constexpr (SPARSE_DECOMP) {
            row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax);
        } else {
            row_local_absmax = fmaxf(row_local_absmax, absval);
        }
    }

    // Reduce thread-local absmax across the block. `cols` bounds the number of
    // valid partials when the row is shorter than the block.
    const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
    if (threadIdx.x == 0) {
        // Save our block's absmax to shared memory for the quantization step.
        rowStats[row_id] = smem_row_absmax = row_absmax;
    }
    // Barrier: all threads must observe thread 0's write to smem_row_absmax
    // before computing the scale below.
    __syncthreads();

    // Quantize row-wise. __fdividef: fast-math division is acceptable here;
    // rounding happens in __float2int_rn.
    const float scale = __fdividef(127.0f, smem_row_absmax);
    for (int i = threadIdx.x; i < cols; i += THREADS) {
        float val = row_data[i];

        if constexpr (SPARSE_DECOMP) {
            // For sparse decomposition, we do not want to quantize the outliers.
            // Instead they're zeroed out. (fabsf, not fabs: avoid a silent
            // promotion to double on a float operand.)
            out[row_id * cols + i] = fabsf(val) < threshold ? __float2int_rn(val * scale) : 0;
        } else {
            out[row_id * cols + i] = __float2int_rn(val * scale);
        }
    }
}
0 commit comments