@@ -1767,15 +1767,7 @@ template <typename T, int THREADS, int SPARSE_DECOMP>
17671767__launch_bounds__ (1024 , BNB_MAX_THREADS_PER_SM / 1024 ) __global__
17681768 void kInt8VectorQuant(T* __restrict__ A, int8_t * out, float * rowStats, float threshold, int rows, int cols) {
17691769
1770- // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
1771- // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
1772- #if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE
1773- using TReduction = T;
1774- #else
1775- using TReduction = float ;
1776- #endif
1777-
1778- using BlockReduceT = cub::BlockReduce<TReduction, THREADS>;
1770+ using BlockReduceT = cub::BlockReduce<T, THREADS>;
17791771
17801772 // One block per row.
17811773 // Threads load column values in a striped arrangement.
@@ -1785,27 +1777,27 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
17851777 // We then do a blockwise reduction to determine the row's absmax.
17861778
17871779 __shared__ typename BlockReduceT::TempStorage temp_storage;
1788- __shared__ TReduction smem_row_absmax;
1780+ __shared__ T smem_row_absmax;
17891781
17901782 const int row_id = blockIdx .x ;
17911783 const T* row_data = A + (row_id * cols);
17921784
17931785 // Threads will read the row values in a striped access pattern and find a local absmax.
1794- TReduction row_local_absmax = -FLT_MIN;
1786+ T row_local_absmax = -FLT_MIN;
17951787 for (int i = threadIdx .x ; i < cols; i += THREADS) {
1796- const TReduction absval = fabsf (__ldcs (&(row_data[i])));
1788+ const T absval = fabsf (__ldcs (&(row_data[i])));
17971789
17981790 // For sparse decomposition, values outside of the threshold are not to be
17991791 // included when calculating the row's absmax.
18001792 if constexpr (SPARSE_DECOMP) {
1801- row_local_absmax = fmaxf (row_local_absmax, absval < TReduction (threshold) ? absval : row_local_absmax);
1793+ row_local_absmax = fmaxf (row_local_absmax, absval < T (threshold) ? absval : row_local_absmax);
18021794 } else {
18031795 row_local_absmax = fmaxf (row_local_absmax, absval);
18041796 }
18051797 }
18061798
18071799 // Reduce thread-local absmax across the block.
1808- const TReduction row_absmax = BlockReduceT (temp_storage).Reduce (row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
1800+ const T row_absmax = BlockReduceT (temp_storage).Reduce (row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
18091801 if (threadIdx .x == 0 ) {
18101802 // Save our block's absmax to shared memory for the quantization step.
18111803 rowStats[row_id] = smem_row_absmax = row_absmax;
0 commit comments