Commit 40e6971

CUDA 13 build enablement
1 parent: b1f80b8

5 files changed: 62 additions & 36 deletions


.github/scripts/build-cuda.sh

Lines changed: 4 additions & 1 deletion
@@ -12,13 +12,16 @@ elif [ "${build_arch}" = "aarch64" ]; then
     build_capability="75;80;90"
 
     # CUDA 12.8+: Add sm100/sm120
-    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120"
+    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* || "${cuda_version}" == 13.*.* ]] && build_capability="75;80;90;100;120"
 else
     # By default, target Pascal through Hopper.
     build_capability="60;70;75;80;86;89;90"
 
     # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum
     [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120"
+
+    # CUDA 13.0+: Remove < sm75 to align with PyTorch 2.9+cu130 minimum
+    [[ "${cuda_version}" == 13.*.* ]] && build_capability="75;80;86;89;90;100;120"
 fi
 
 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja

.github/workflows/python-package.yml

Lines changed: 4 additions & 4 deletions
@@ -72,17 +72,17 @@ jobs:
           - os: windows-latest
             arch: x86_64
         cuda_version:
-          ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1", "12.9.1"]
+          ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1", "12.9.1", "13.0.1"]
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v4
       # Windows: We install Cuda on the agent (slow)
-      - uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24
+      - uses: Jimver/cuda-toolkit@433d453c1fa37d10a3254452fa8e284441c9192d # v0.2.27
        if: startsWith(matrix.os, 'windows')
        id: cuda-toolkit
        with:
-          # Temporary: Use CUDA 12.9.0 for Windows until 12.9.1 is supported with this action.
-          cuda: ${{ matrix.cuda_version == '12.9.1' && '12.9.0' || matrix.cuda_version }}
+          # Temporary: Use CUDA 13.0.0 for Windows until 13.0.1 is supported with this action.
+          cuda: ${{ matrix.cuda_version == '13.0.1' && '13.0.0' || matrix.cuda_version }}
          method: "network"
          sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]'
          linux-local-args: '["--toolkit"]'

CMakeLists.txt

Lines changed: 25 additions & 19 deletions
@@ -113,30 +113,36 @@ if(BUILD_CUDA)
     )
   endif()
 
-  if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.4")
-    message(FATAL_ERROR "CUDA Version < 11.4 is not supported")
-  elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0")
-    message(FATAL_ERROR "CUDA Version > 12 is not supported")
+  if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.8")
+    message(FATAL_ERROR "CUDA Version < 11.8 is not supported")
+  elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0")
+    message(FATAL_ERROR "CUDA Version > 13 is not supported")
   endif()
 
   # CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL.
   if(CMAKE_VERSION VERSION_LESS "3.23.0")
     message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...")
 
-    # 11.4+ supports these at a minimum.
-    set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80 86 87)
-    set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80)
-
-    # CUDA 11.8 adds support for Ada and Hopper.
-    if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8")
-      list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 89 90)
-      list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 90)
-    endif()
-
-    # CUDA 12.8 adds support for Blackwell.
-    if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.8")
-      list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 100 101 120)
-      list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 100 120)
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0")
+      # Starting in CUDA 13.0, Thor Blackwell is renamed to SM110.
+      # Support for architectures older than Turing (SM75) is removed.
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 75 80 86 87 88 89 90 100 103 110 120 121)
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 80 90 100 110 120)
+    else()
+      # 11.8-12.9 supports these at a minimum.
+      set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80 86 87 89 90)
+      set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80 90)
+
+      # CUDA 12.8 adds support for Blackwell.
+      if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.8")
+        list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 100 101 120 121)
+        list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 100 120)
+      endif()
+
+      # CUDA 12.9 adds SM103 (Blackwell B300).
+      if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
+        list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 103)
+      endif()
     endif()
   endif()
 
@@ -252,7 +258,7 @@ endif()
 
 set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
 add_library(bitsandbytes SHARED ${SRC_FILES})
-target_compile_features(bitsandbytes PUBLIC cxx_std_14)
+target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)
 

csrc/kernels.cu

Lines changed: 20 additions & 12 deletions
@@ -16,6 +16,14 @@
 #include <math_constants.h>
 #include <mma.h>
 
+#if CCCL_VERSION >= 2008002
+#include <cuda/std/functional>
+#define CUB_REDUCTIONOP_MAX \
+    cuda::maximum<> {}
+#else
+#define CUB_REDUCTIONOP_MAX cub::Max()
+#endif
+
 #define HLF_MAX 65504
 #define TH 1024
 #define NUM 4
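
Background on this macro: cub::Max() was deprecated in CCCL 2.x and removed in CCCL 3.0, the version bundled with CUDA 13; its replacement is the cuda::maximum<> functor from <cuda/std/functional>. CCCL_VERSION 2008002 encodes CCCL 2.8.2. A minimal sketch of how the macro is consumed — the kernel below is illustrative only, not code from this commit:

// Illustrative only (not from this commit): a block-level absmax reduction
// that compiles against both pre- and post-3.0 CCCL via the same macro.
#include <cub/block/block_reduce.cuh>
#if CCCL_VERSION >= 2008002
#include <cuda/std/functional> // provides cuda::maximum<>
#define CUB_REDUCTIONOP_MAX cuda::maximum<>{}
#else
#define CUB_REDUCTIONOP_MAX cub::Max()
#endif

template <int BLOCK_SIZE>
__global__ void block_absmax(const float* in, float* out, int n) {
    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp;

    int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    float v = (i < n) ? fabsf(in[i]) : 0.0f;

    // Expands to cuda::maximum<>{} on CCCL >= 2.8.2, cub::Max() otherwise.
    float block_max = BlockReduce(temp).Reduce(v, CUB_REDUCTIONOP_MAX);
    if (threadIdx.x == 0)
        out[blockIdx.x] = block_max; // reduction result is valid on thread 0
}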
@@ -365,7 +373,7 @@ __global__ void kQuantizeBlockwise(
     for (int j = 0; j < NUM_PER_TH; j++)
         local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
 
-    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
+    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX, valid_items);
 
     if (threadIdx.x == 0) {
         smem_absmax_value[0] = 1.0f / local_abs_max;

@@ -951,12 +959,12 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
     }
 
     __syncthreads();
-    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
+    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, CUB_REDUCTIONOP_MAX, valid_items);
     __syncthreads();
-    local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
+    local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, CUB_REDUCTIONOP_MAX, valid_items);
     if (unorm != NULL) {
         __syncthreads();
-        local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
+        local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items);
     }
 
     if (threadIdx.x == 0) {

@@ -1162,13 +1170,13 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b
     }
 
     __syncthreads();
-    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
+    local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, CUB_REDUCTIONOP_MAX, valid_items);
     if (threadIdx.x == 0) {
         atomicMax(&new_max1[0], local_max_s1);
     }
     if (unorm != NULL) {
         __syncthreads();
-        local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
+        local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items);
         if (threadIdx.x == 0) {
             atomicAdd(&unorm[0], local_unorm);
         }
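
The cub::Sum() call sites above are handled differently: rather than guarding the functor behind the macro, the commit switches to BlockReduce's Sum(input, num_valid) convenience overload, which exists in both old and new CUB. A hypothetical standalone usage, not from this commit:

// Illustrative only: BlockReduce::Sum() needs no explicit functor, so it is
// unaffected by the cub::Sum() removal in CCCL 3.0.
#include <cub/block/block_reduce.cuh>

template <int BLOCK_SIZE>
__global__ void block_sum(const float* in, float* out, int valid_items) {
    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp;

    int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    float v = (threadIdx.x < valid_items) ? in[i] : 0.0f;

    // Same result as Reduce(v, cub::Sum(), valid_items) on pre-3.0 CCCL.
    float total = BlockReduce(temp).Sum(v, valid_items);
    if (threadIdx.x == 0)
        out[blockIdx.x] = total;
}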
@@ -1473,11 +1481,11 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise(
     }
 
     // reduce: 2.51/1.60 -> 2.67/1.69
-    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
-    new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
+    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX);
+    new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, CUB_REDUCTIONOP_MAX);
 
     if (OPTIMIZER == ADEMAMIX) {
-        new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max());
+        new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, CUB_REDUCTIONOP_MAX);
     }
 
     if (threadIdx.x == 0) {

@@ -1686,7 +1694,7 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise(
     }
 
     // reduce: 2.51/1.60 -> 2.67/1.69
-    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
+    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX);
 
     if (threadIdx.x == 0)
         smem_exchange1[0] = new_local_abs_max1;

@@ -1792,7 +1800,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
     }
 
     // Reduce thread-local absmax across the block.
-    const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
+    const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
     if (threadIdx.x == 0) {
         // Save our block's absmax to shared memory for the quantization step.
         rowStats[row_id] = smem_row_absmax = row_absmax;

@@ -1847,7 +1855,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
 
     // Reduce thread-local absmax across the block.
     // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
-    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
+    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
     if (threadIdx.x == 0) {
         // Save our block's absmax to shared memory for the quantization step.
         rowStats[row_id] = row_absmax;

csrc/pythonInterface.cpp

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@
 // LICENSE file in the root directory of this source tree.
 
 #if BUILD_CUDA
+#include <cuda_runtime_api.h>
 #include <ops.cuh>
 #endif
 #if BUILD_HIP

@@ -710,7 +711,15 @@ void cprefetch(void* ptr, size_t bytes, int device) {
     if (hasPrefetch == 0)
         return;
 
+#if CUDART_VERSION >= 13000
+    cudaMemLocation loc{};
+    loc.type = cudaMemLocationTypeDevice;
+    loc.id = device;
+    CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, loc, 0u, 0));
+#else
     CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
+#endif
+
     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
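
CUDA 13 changed the signature of cudaMemPrefetchAsync: the destination is now described by a cudaMemLocation struct plus a flags argument rather than a bare device ordinal, and the commit adds an explicit <cuda_runtime_api.h> include so CUDART_VERSION and the struct are visible here. A standalone sketch of the same version guard (hypothetical program, not part of this commit), assuming a managed allocation:

// Illustrative only: prefetch a managed buffer to device 0 under both APIs.
#include <cuda_runtime_api.h>

int main() {
    const size_t bytes = 1 << 20;
    float* ptr = nullptr;
    if (cudaMallocManaged(&ptr, bytes) != cudaSuccess)
        return 1;

#if CUDART_VERSION >= 13000
    // CUDA 13+: destination expressed as a cudaMemLocation.
    cudaMemLocation loc{};
    loc.type = cudaMemLocationTypeDevice;
    loc.id = 0;
    cudaMemPrefetchAsync(ptr, bytes, loc, 0u /* flags */, 0 /* stream */);
#else
    // CUDA 11/12: destination is a bare device ordinal.
    cudaMemPrefetchAsync(ptr, bytes, 0 /* device */, 0 /* stream */);
#endif

    cudaDeviceSynchronize();
    cudaFree(ptr);
    return 0;
}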
