@@ -2142,7 +2142,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
21422142template <typename T, int THREADS, int SPARSE_DECOMP>
21432143__launch_bounds__ (1024 , BNB_MAX_THREADS_PER_SM / 1024 )
21442144__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t * out, float * rowStats, float threshold, int rows, int cols) {
2145- using BlockReduceT = cub::BlockReduce<T, THREADS>;
2145+
2146+ // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
2147+ // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
2148+ #if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE && __CUDACC__
2149+ using TReduction = T;
2150+ #else
2151+ using TReduction = float ;
2152+ #endif
2153+
2154+ using BlockReduceT = cub::BlockReduce<TReduction, THREADS>;
21462155
21472156 // One block per row.
21482157 // Threads load column values in a striped arrangement.
@@ -2152,27 +2161,27 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21522161 // We then do a blockwise reduction to determine the row's absmax.
21532162
21542163 __shared__ typename BlockReduceT::TempStorage temp_storage;
2155- __shared__ T smem_row_absmax;
2164+ __shared__ TReduction smem_row_absmax;
21562165
21572166 const int row_id = blockIdx .x ;
21582167 const T* row_data = A + (row_id * cols);
21592168
21602169 // Threads will read the row values in a striped access pattern and find a local absmax.
2161- T row_local_absmax = -FLT_MIN;
2170+ TReduction row_local_absmax = -FLT_MIN;
21622171 for (int i = threadIdx .x ; i < cols; i += THREADS) {
2163- const T absval = fabsf (__ldcs (&(row_data[i])));
2172+ const TReduction absval = fabsf (__ldcs (&(row_data[i])));
21642173
21652174 // For sparse decomposition, values outside of the threshold are not to be
21662175 // included when calculating the row's absmax.
21672176 if constexpr (SPARSE_DECOMP) {
2168- row_local_absmax = fmaxf (row_local_absmax, absval < T (threshold) ? absval : row_local_absmax);
2177+ row_local_absmax = fmaxf (row_local_absmax, absval < TReduction (threshold) ? absval : row_local_absmax);
21692178 } else {
21702179 row_local_absmax = fmaxf (row_local_absmax, absval);
21712180 }
21722181 }
21732182
21742183 // Reduce thread-local absmax across the block.
2175- const T row_absmax = BlockReduceT (temp_storage).Reduce (row_local_absmax, cub::Max (), cols);
2184+ const TReduction row_absmax = BlockReduceT (temp_storage).Reduce (row_local_absmax, cub::Max (), cols);
21762185 if (threadIdx .x == 0 ) {
21772186 // Save our block's absmax to shared memory for the quantization step.
21782187 rowStats[row_id] = smem_row_absmax = row_absmax;
0 commit comments