
Commit bbafcc3

CUDA: Drop Maxwell compatibility
Parent: bcdc4de

2 files changed: +6, -18 lines

csrc/common.cuh (0 additions, 4 deletions)

@@ -2,9 +2,6 @@
 
 // TODO: Let's make some of these constexpr and put in a namespace.
 
-#define BNB_CC_MAXWELL 500
-#define BNB_CC_MAXWELL2 520
-#define BNB_CC_MAXWELL2_X1 530
 #define BNB_CC_PASCAL 600
 #define BNB_CC_PASCAL_X2 620
 #define BNB_CC_VOLTA 700
@@ -17,7 +14,6 @@
 #define BNB_CC_HOPPER 900
 #define BNB_CC_BLACKWELL 1000
 
-#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
 #define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
 #define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
 #define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
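
For reference, a minimal sketch of how compute-capability macros in this style are typically consumed in device code: values derived from __CUDA_ARCH__ gate architecture-specific paths at compile time. The kernel, the locally repeated BNB_CC_AMPERE value, and the guard below are illustrative assumptions, not code from this repository.

// Illustrative sketch: gate a native bf16 path on Ampere (sm_80) or newer,
// mirroring how BNB_BF16_AVAILABLE is derived from __CUDA_ARCH__ above.
#include <cuda_bf16.h>

#define BNB_CC_AMPERE 800 // assumed value, following the PASCAL=600 / VOLTA=700 / HOPPER=900 pattern

__global__ void scale_bf16(__nv_bfloat16* x, float s, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n)
        return;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= BNB_CC_AMPERE)
    // Native bf16 arithmetic is available from Ampere onward.
    x[i] = __hmul(x[i], __float2bfloat16(s));
#else
    // Older architectures: round-trip through fp32.
    x[i] = __float2bfloat16(__bfloat162float(x[i]) * s);
#endif
}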

csrc/kernels.cu (6 additions, 14 deletions)

@@ -1767,15 +1767,7 @@ template <typename T, int THREADS, int SPARSE_DECOMP>
 __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
     void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
 
-    // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
-    // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
-#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) || BNB_FP16_AVAILABLE
-    using TReduction = T;
-#else
-    using TReduction = float;
-#endif
-
-    using BlockReduceT = cub::BlockReduce<TReduction, THREADS>;
+    using BlockReduceT = cub::BlockReduce<T, THREADS>;
 
     // One block per row.
     // Threads load column values in a striped arrangement.
@@ -1785,27 +1777,27 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
     // We then do a blockwise reduction to determine the row's absmax.
 
     __shared__ typename BlockReduceT::TempStorage temp_storage;
-    __shared__ TReduction smem_row_absmax;
+    __shared__ T smem_row_absmax;
 
     const int row_id = blockIdx.x;
     const T* row_data = A + (row_id * cols);
 
     // Threads will read the row values in a striped access pattern and find a local absmax.
-    TReduction row_local_absmax = -FLT_MIN;
+    T row_local_absmax = -FLT_MIN;
     for (int i = threadIdx.x; i < cols; i += THREADS) {
-        const TReduction absval = fabsf(__ldcs(&(row_data[i])));
+        const T absval = fabsf(__ldcs(&(row_data[i])));
 
         // For sparse decomposition, values outside of the threshold are not to be
         // included when calculating the row's absmax.
         if constexpr (SPARSE_DECOMP) {
-            row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax);
+            row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax);
         } else {
             row_local_absmax = fmaxf(row_local_absmax, absval);
         }
     }
 
     // Reduce thread-local absmax across the block.
-    const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
+    const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
     if (threadIdx.x == 0) {
         // Save our block's absmax to shared memory for the quantization step.
         rowStats[row_id] = smem_row_absmax = row_absmax;
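
The kernel's comments spell out the pattern: one block per row, threads reading column values in a striped arrangement to find a thread-local absmax, then a blockwise reduction to obtain the row's absmax. Below is a simplified, self-contained sketch of that row-wise absmax int8 quantization scheme. It is not the kernel above: it assumes fp32 input, omits the SPARSE_DECOMP threshold handling, and adds the quantization step for completeness.

// Simplified sketch of row-wise absmax int8 quantization (illustrative, not kInt8VectorQuant).
#include <cub/cub.cuh>

template <int THREADS>
__global__ void rowwise_int8_quant(const float* __restrict__ A, int8_t* out, float* rowStats, int cols) {
    using BlockReduceT = cub::BlockReduce<float, THREADS>;
    __shared__ typename BlockReduceT::TempStorage temp_storage;
    __shared__ float smem_row_absmax;

    // One block per row.
    const int row_id = blockIdx.x;
    const float* row_data = A + row_id * cols;

    // Striped loads: thread t reads columns t, t + THREADS, t + 2*THREADS, ...
    float local_absmax = 0.0f;
    for (int i = threadIdx.x; i < cols; i += THREADS)
        local_absmax = fmaxf(local_absmax, fabsf(row_data[i]));

    // Blockwise max-reduction; only thread 0 receives the valid result.
    const float row_absmax = BlockReduceT(temp_storage).Reduce(local_absmax, cub::Max(), cols);
    if (threadIdx.x == 0)
        rowStats[row_id] = smem_row_absmax = row_absmax;
    __syncthreads();

    // Quantize: map [-absmax, absmax] onto [-127, 127].
    const float scale = smem_row_absmax > 0.0f ? 127.0f / smem_row_absmax : 0.0f;
    for (int i = threadIdx.x; i < cols; i += THREADS)
        out[row_id * cols + i] = static_cast<int8_t>(rintf(row_data[i] * scale));
}

A launch would use one block per row, e.g. rowwise_int8_quant<256><<<rows, 256>>>(A, out, rowStats, cols).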
