Skip to content

Commit fa6f597

Browse files
maxwell compat
1 parent 196c8e0 commit fa6f597

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

csrc/kernels.cu

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2142,7 +2142,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
21422142
template<typename T, int THREADS, int SPARSE_DECOMP>
21432143
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
21442144
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
2145-
using BlockReduceT = cub::BlockReduce<T, THREADS>;
2145+
2146+
// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
2147+
// Otherwise the reduction is done directly in `T` (fp16). This can be removed when Maxwell is dropped.
2148+
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) || BNB_FP16_AVAILABLE && __CUDACC__
2149+
using TReduction = T;
2150+
#else
2151+
using TReduction = float;
2152+
#endif
2153+
2154+
using BlockReduceT = cub::BlockReduce<TReduction, THREADS>;
21462155

21472156
// One block per row.
21482157
// Threads load column values in a striped arrangement.
@@ -2152,27 +2161,27 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21522161
// We then do a blockwise reduction to determine the row's absmax.
21532162

21542163
__shared__ typename BlockReduceT::TempStorage temp_storage;
2155-
__shared__ T smem_row_absmax;
2164+
__shared__ TReduction smem_row_absmax;
21562165

21572166
const int row_id = blockIdx.x;
21582167
const T* row_data = A + (row_id * cols);
21592168

21602169
// Threads will read the row values in a striped access pattern and find a local absmax.
2161-
T row_local_absmax = -FLT_MIN;
2170+
TReduction row_local_absmax = -FLT_MIN;
21622171
for (int i = threadIdx.x; i < cols; i += THREADS) {
2163-
const T absval = fabsf(__ldcs(&(row_data[i])));
2172+
const TReduction absval = fabsf(__ldcs(&(row_data[i])));
21642173

21652174
// For sparse decomposition, values outside of the threshold are not to be
21662175
// included when calculating the row's absmax.
21672176
if constexpr (SPARSE_DECOMP) {
2168-
row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax);
2177+
row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax);
21692178
} else {
21702179
row_local_absmax = fmaxf(row_local_absmax, absval);
21712180
}
21722181
}
21732182

21742183
// Reduce thread-local absmax across the block.
2175-
const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
2184+
const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
21762185
if (threadIdx.x == 0) {
21772186
// Save our block's absmax to shared memory for the quantization step.
21782187
rowStats[row_id] = smem_row_absmax = row_absmax;

0 commit comments

Comments
 (0)