Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 9ea57a8

Browse files
authored
Improve CUDA capability handling (#329)
We computed a kernel's capabilities by taking the loose intersection of the stated kernel capabilities (or the default) and the capabilities reported to be supported by CMake/Torch. However, this led to issues with e.g. capability 8.9, which is not in these lists (anymore?), but is fine to compile for. To solve this issue, we will ignore the capabilities reported by CMake/Torch and instead use our own list of capabilities for the loose intersection with the kernel capabilities. This list is the list of all capabilities supported by a CUDA version minus some really old capabilities that are not supported by Torch anyway. This behavior is used by enabling the new `BUILD_ALL_SUPPORTED_ARCHS` CMake option (which is the default for the Nix and Windows builders). When `BUILD_ALL_SUPPORTED_ARCHS` is not set, we will try to detect the capability of the user's CUDA GPU. This speeds up development - since one then only has to compile for a single capability. If this fails for some reason, we'll revert to using all capabilities as if `BUILD_ALL_SUPPORTED_ARCHS` was set.
1 parent 0253b68 commit 9ea57a8

File tree

7 files changed

+91
-44
lines changed

7 files changed

+91
-44
lines changed

build2cmake/src/templates/cuda/kernel.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ if(GPU_LANG STREQUAL "CUDA")
1818
{% if cuda_capabilities %}
1919
cuda_archs_loose_intersection({{kernel_name}}_ARCHS "{{ cuda_capabilities|join(";") }}" "${CUDA_ARCHS}")
2020
{% else %}
21-
cuda_archs_loose_intersection({{kernel_name}}_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}" "${CUDA_ARCHS}")
21+
set({{kernel_name}}_ARCHS "${CUDA_KERNEL_ARCHS}")
2222
{% endif %}
2323
message(STATUS "Capabilities for kernel {{kernel_name}}: {{ '${' + kernel_name + '_ARCHS}'}}")
2424
set_gencode_flags_for_srcs(SRCS {{'"${' + kernel_name + '_SRC}"'}} CUDA_ARCHS "{{ '${' + kernel_name + '_ARCHS}'}}")

build2cmake/src/templates/cuda/preamble.cmake

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ include(FetchContent)
99
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
1010
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
1111

12-
set(CUDA_SUPPORTED_ARCHS "{{ cuda_supported_archs }}")
13-
1412
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
1513

1614
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
@@ -50,6 +48,8 @@ if (NOT TARGET_DEVICE STREQUAL "cuda" AND
5048
return()
5149
endif()
5250

51+
option(BUILD_ALL_SUPPORTED_ARCHS "Build all supported architectures" off)
52+
5353
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
5454
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
5555
set(CUDA_DEFAULT_KERNEL_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0+PTX")
@@ -90,13 +90,26 @@ endif()
9090

9191

9292
if(GPU_LANG STREQUAL "CUDA")
93-
clear_cuda_arches(CUDA_ARCH_FLAGS)
94-
extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
95-
message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
96-
# Filter the target architectures by the supported supported archs
97-
# since for some files we will build for all CUDA_ARCHS.
98-
cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
99-
message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
93+
# This clears out -gencode arguments from `CMAKE_CUDA_FLAGS`, which we need
94+
# to set our own set of capabilities.
95+
clear_gencode_flags()
96+
97+
# Get the capabilities without +PTX suffixes, so that we can use them as
98+
# the target archs in the loose intersection with a kernel's capabilities.
99+
cuda_remove_ptx_suffixes(CUDA_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}")
100+
message(STATUS "CUDA supported base architectures: ${CUDA_ARCHS}")
101+
102+
if(BUILD_ALL_SUPPORTED_ARCHS)
103+
set(CUDA_KERNEL_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}")
104+
else()
105+
try_run_python(CUDA_KERNEL_ARCHS SUCCESS "import torch; cc=torch.cuda.get_device_capability(); print(f\"{cc[0]}.{cc[1]}\")" "Failed to get CUDA capability")
106+
if(NOT SUCCESS)
107+
message(WARNING "Failed to detect CUDA capability, using default capabilities.")
108+
set(CUDA_KERNEL_ARCHS "${CUDA_DEFAULT_KERNEL_ARCHS}")
109+
endif()
110+
endif()
111+
112+
message(STATUS "CUDA supported kernel architectures: ${CUDA_KERNEL_ARCHS}")
100113

101114
if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
102115
list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")

build2cmake/src/templates/utils.cmake

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,29 @@ function (run_python OUT EXPR ERR_MSG)
4242
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
4343
endfunction()
4444

45+
#
46+
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
47+
# has trailing whitespace stripped. If an error is encountered when running
48+
# python, `SUCCESS` is set to FALSE. If successful, `SUCCESS` is set to TRUE.
49+
#
50+
function (try_run_python OUT SUCCESS EXPR)
51+
execute_process(
52+
COMMAND
53+
"${Python3_EXECUTABLE}" "-c" "${EXPR}"
54+
OUTPUT_VARIABLE PYTHON_OUT
55+
RESULT_VARIABLE PYTHON_ERROR_CODE
56+
ERROR_QUIET
57+
OUTPUT_STRIP_TRAILING_WHITESPACE)
58+
59+
if(NOT PYTHON_ERROR_CODE EQUAL 0)
60+
set(${SUCCESS} FALSE PARENT_SCOPE)
61+
set(${OUT} "" PARENT_SCOPE)
62+
else()
63+
set(${SUCCESS} TRUE PARENT_SCOPE)
64+
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
65+
endif()
66+
endfunction()
67+
4568
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
4669
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
4770
macro (append_cmake_prefix_path PKG EXPR)
@@ -152,34 +175,28 @@ macro(string_to_ver OUT_VER IN_STR)
152175
endmacro()
153176

154177
#
155-
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
156-
# `CUDA_ARCH_FLAGS`.
178+
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS`.
157179
#
158180
# Example:
159181
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
160-
# clear_cuda_arches(CUDA_ARCH_FLAGS)
161-
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
182+
# clear_gencode_flags()
162183
# CMAKE_CUDA_FLAGS="-Wall"
163184
#
164-
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
165-
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
166-
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
167-
${CMAKE_CUDA_FLAGS})
168-
185+
macro(clear_gencode_flags)
169186
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
170187
# and passed back via the `CUDA_ARCHITECTURES` property.
171188
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
172189
${CMAKE_CUDA_FLAGS})
173190
endmacro()
174191

175192
#
176-
# Extract unique CUDA architectures from a list of compute capabilities codes in
177-
# the form `<major><minor>[<letter>]`, convert them to the form sort
178-
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
193+
# Extract unique CUDA architectures from a list of compute capabilities codes in
194+
# the form `<major><minor>[<letter>]`, convert them to the form sort
195+
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
179196
# stores them in `OUT_ARCHES`.
180197
#
181198
# Example:
182-
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
199+
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
183200
# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
184201
# OUT_ARCHES="7.5;...;9.0"
185202
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
@@ -200,15 +217,15 @@ function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
200217
endfunction()
201218

202219
#
203-
# For a specific file set the `-gencode` flag in compile options conditionally
204-
# for the CUDA language.
220+
# For a specific file set the `-gencode` flag in compile options conditionally
221+
# for the CUDA language.
205222
#
206223
# Example:
207224
# set_gencode_flag_for_srcs(
208225
# SRCS "foo.cu"
209226
# ARCH "compute_75"
210227
# CODE "sm_75")
211-
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
228+
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
212229
# `foo.cu` (only for the CUDA language).
213230
#
214231
macro(set_gencode_flag_for_srcs)
@@ -228,14 +245,14 @@ macro(set_gencode_flag_for_srcs)
228245
endmacro(set_gencode_flag_for_srcs)
229246

230247
#
231-
# For a list of source files set the `-gencode` flags in the files specific
248+
# For a list of source files set the `-gencode` flags in the files specific
232249
# compile options (specifically for the CUDA language).
233250
#
234251
# arguments are:
235252
# SRCS: list of source files
236253
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
237254
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
238-
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
255+
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
239256
# that is larger than BUILD_PTX_FOR_ARCH.
240257
#
241258
macro(set_gencode_flags_for_srcs)
@@ -383,12 +400,14 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
383400
endforeach()
384401
set(_CUDA_ARCHS ${_FINAL_ARCHS})
385402

403+
list(SORT _CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
404+
386405
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
387406
endfunction()
388407

389408
#
390-
# For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
391-
# `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
409+
# For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
410+
# `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
392411
# The loose intersection is defined as:
393412
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
394413
# where `<=` is the version comparison operator.
@@ -404,28 +423,48 @@ endfunction()
404423
#
405424
function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
406425
list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
407-
426+
408427
# ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
409428
# and x is a letter. We can sort them by string comparison which works for this format.
410429
list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
411-
430+
412431
set(_ROCM_ARCHS)
413-
432+
414433
# Find the intersection of supported architectures
415434
foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
416435
if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
417436
list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
418437
endif()
419438
endforeach()
420-
439+
421440
list(REMOVE_DUPLICATES _ROCM_ARCHS)
422441
set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
423442
endfunction()
424443

444+
function(cuda_remove_ptx_suffixes OUT_CUDA_ARCHS CUDA_ARCHS)
445+
set(_CUDA_ARCHS "${CUDA_ARCHS}")
446+
447+
# handle +PTX suffix: separate base arch for matching, record PTX requests
448+
foreach(_arch ${CUDA_ARCHS})
449+
if(_arch MATCHES "\\+PTX$")
450+
string(REPLACE "+PTX" "" _base "${_arch}")
451+
list(REMOVE_ITEM _CUDA_ARCHS "${_arch}")
452+
list(APPEND _CUDA_ARCHS "${_base}")
453+
endif()
454+
endforeach()
455+
456+
list(REMOVE_DUPLICATES _CUDA_ARCHS)
457+
list(SORT _CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
458+
459+
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
460+
endfunction()
461+
462+
463+
425464
#
426465
# Override the GPU architectures detected by cmake/torch and filter them by
427466
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
428-
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
467+
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
429468
# the architectures on a per file basis.
430469
#
431470
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.

build2cmake/src/torch/cuda.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@ static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
1818
static WINDOWS_UTILS: &str = include_str!("../templates/windows.cmake");
1919
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
2020
static HIPIFY: &str = include_str!("../templates/cuda/hipify.py");
21-
static CUDA_SUPPORTED_ARCHS_JSON: &str = include_str!("../cuda_supported_archs.json");
22-
23-
fn cuda_supported_archs() -> String {
24-
let supported_archs: Vec<String> = serde_json::from_str(CUDA_SUPPORTED_ARCHS_JSON)
25-
.expect("Error parsing supported CUDA archs");
26-
supported_archs.join(";")
27-
}
2821

2922
pub fn write_torch_ext_cuda(
3023
env: &Environment,
@@ -417,7 +410,6 @@ pub fn render_preamble(
417410
cuda_maxver => cuda_maxver.map(|v| v.to_string()),
418411
torch_minver => torch_minver.map(|v| v.to_string()),
419412
torch_maxver => torch_maxver.map(|v| v.to_string()),
420-
cuda_supported_archs => cuda_supported_archs(),
421413
platform => env::consts::OS
422414
},
423415
&mut *write,

lib/torch-extension/arch.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ stdenv.mkDerivation (prevAttrs: {
216216
dontSetupCUDAToolkitCompilers = true;
217217

218218
cmakeFlags = [
219+
(lib.cmakeBool "BUILD_ALL_SUPPORTED_ARCHS" true)
219220
(lib.cmakeFeature "Python_EXECUTABLE" "${python3.withPackages (ps: [ torch ])}/bin/python")
220221
# Fix: file RPATH_CHANGE could not write new RPATH, we are rewriting
221222
# rpaths anyway.

pkgs/build2cmake/default.nix

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ rustPlatform.buildRustPackage {
2121
|| file.name == "Cargo.lock"
2222
|| file.name == "pyproject.toml"
2323
|| file.name == "pyproject_universal.toml"
24-
|| file.name == "cuda_supported_archs.json"
2524
|| file.name == "python_dependencies.json"
2625
|| (builtins.any file.hasExt [
2726
"cmake"

scripts/windows/builder.ps1

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ function Get-CMakeConfigureArgs {
350350
$kwargs = @("..", "-G", "Visual Studio 17 2022", "-A", $vsArch)
351351
}
352352

353+
# Build for all supported GPU archs, not just the detected arch.
354+
$kwargs += "-DBUILD_ALL_SUPPORTED_ARCHS"
355+
353356
# Detect Python from current environment
354357
$pythonExe = (Get-Command python -ErrorAction SilentlyContinue).Source
355358
if ($pythonExe) {

0 commit comments

Comments
 (0)