Skip to content

Commit 8a60bfb

Browse files
authored
Merge pull request #514 from NVlabs/cuda-13
CUDA 13
2 parents 99388eb + 800e401 commit 8a60bfb

File tree

5 files changed

+95
-99
lines changed

5 files changed

+95
-99
lines changed

.github/workflows/main.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
include:
1818
- os: ubuntu-24.04
19-
cuda: "12.8"
19+
cuda: "13.0"
2020
arch: 120
2121
- os: ubuntu-24.04
2222
cuda: "12.8"
@@ -77,15 +77,15 @@ jobs:
7777
include:
7878
- os: windows-2025
7979
visual_studio: "Visual Studio 17 2022"
80-
cuda: "12.9.1"
80+
cuda: "13.0.0"
8181
arch: 120
8282
- os: windows-2025
8383
visual_studio: "Visual Studio 17 2022"
84-
cuda: "12.6.3"
84+
cuda: "12.9.1"
8585
arch: 89
8686
- os: windows-2022
8787
visual_studio: "Visual Studio 17 2022"
88-
cuda: "12.6.3"
88+
cuda: "12.9.1"
8989
arch: 86
9090
- os: windows-2022
9191
visual_studio: "Visual Studio 17 2022"

CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
2-
#
2+
#
33
# Redistribution and use in source and binary forms, with or without modification, are permitted
44
# provided that the following conditions are met:
55
# * Redistributions of source code must retain the above copyright notice, this list of
@@ -10,7 +10,7 @@
1010
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
1111
# to endorse or promote products derived from this software without specific prior written
1212
# permission.
13-
#
13+
#
1414
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
1515
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
1616
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
@@ -165,7 +165,9 @@ else()
165165
set(LATEST_SUPPORTED_CUDA_ARCHITECTURE 120)
166166
endif()
167167

168-
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
168+
if (CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
169+
set(EARLIEST_SUPPORTED_CUDA_ARCHITECTURE 75)
170+
elseif (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
169171
set(EARLIEST_SUPPORTED_CUDA_ARCHITECTURE 50)
170172
else()
171173
set(EARLIEST_SUPPORTED_CUDA_ARCHITECTURE 20)

bindings/torch/setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
ROOT_DIR = os.path.dirname(os.path.dirname(SCRIPT_DIR))
1515

1616
def min_supported_compute_capability(cuda_version):
17-
if cuda_version >= parse_version("12.0"):
17+
if cuda_version >= parse_version("13.0"):
18+
return 75
19+
elif cuda_version >= parse_version("12.0"):
1820
return 50
1921
else:
2022
return 20

dependencies/cuda-cmake-github-actions/scripts/actions/install_cuda_windows.ps1

Lines changed: 74 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,54 +4,55 @@
44

55
# Dictionary of known cuda versions and thier download URLS, which do not follow a consistent pattern :(
66
$CUDA_KNOWN_URLS = @{
7-
"8.0.44" = "http://developer.nvidia.com/compute/cuda/8.0/Prod/network_installers/cuda_8.0.44_win10_network-exe";
8-
"8.0.61" = "http://developer.nvidia.com/compute/cuda/8.0/Prod2/network_installers/cuda_8.0.61_win10_network-exe";
9-
"9.0.176" = "http://developer.nvidia.com/compute/cuda/9.0/Prod/network_installers/cuda_9.0.176_win10_network-exe";
10-
"9.1.85" = "http://developer.nvidia.com/compute/cuda/9.1/Prod/network_installers/cuda_9.1.85_win10_network";
11-
"9.2.148" = "http://developer.nvidia.com/compute/cuda/9.2/Prod2/network_installers2/cuda_9.2.148_win10_network";
12-
"10.0.130" = "http://developer.nvidia.com/compute/cuda/10.0/Prod/network_installers/cuda_10.0.130_win10_network";
13-
"10.1.105" = "http://developer.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.105_win10_network.exe";
14-
"10.1.168" = "http://developer.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.168_win10_network.exe";
15-
"10.1.243" = "http://developer.download.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.243_win10_network.exe";
16-
"10.2.89" = "http://developer.download.nvidia.com/compute/cuda/10.2/Prod/network_installers/cuda_10.2.89_win10_network.exe";
17-
"11.0.1" = "http://developer.download.nvidia.com/compute/cuda/11.0.1/network_installers/cuda_11.0.1_win10_network.exe";
18-
"11.0.2" = "http://developer.download.nvidia.com/compute/cuda/11.0.2/network_installers/cuda_11.0.2_win10_network.exe";
19-
"11.0.3" = "http://developer.download.nvidia.com/compute/cuda/11.0.3/network_installers/cuda_11.0.3_win10_network.exe";
20-
"11.1.0" = "https://developer.download.nvidia.com/compute/cuda/11.1.0/network_installers/cuda_11.1.0_win10_network.exe";
21-
"11.1.1" = "https://developer.download.nvidia.com/compute/cuda/11.1.1/network_installers/cuda_11.1.1_win10_network.exe";
22-
"11.2.0" = "https://developer.download.nvidia.com/compute/cuda/11.2.0/network_installers/cuda_11.2.0_win10_network.exe";
23-
"11.2.1" = "https://developer.download.nvidia.com/compute/cuda/11.2.1/network_installers/cuda_11.2.1_win10_network.exe";
24-
"11.2.2" = "https://developer.download.nvidia.com/compute/cuda/11.2.2/network_installers/cuda_11.2.2_win10_network.exe";
25-
"11.3.0" = "https://developer.download.nvidia.com/compute/cuda/11.3.0/network_installers/cuda_11.3.0_win10_network.exe";
26-
"11.3.1" = "https://developer.download.nvidia.com/compute/cuda/11.3.1/network_installers/cuda_11.3.1_win10_network.exe";
27-
"11.5.0" = "https://developer.download.nvidia.com/compute/cuda/11.5.0/network_installers/cuda_11.5.0_win10_network.exe";
28-
"11.5.1" = "https://developer.download.nvidia.com/compute/cuda/11.5.1/network_installers/cuda_11.5.1_windows_network.exe";
29-
"11.8.0" = "https://developer.download.nvidia.com/compute/cuda/11.8.0/network_installers/cuda_11.8.0_windows_network.exe";
30-
"12.5.0" = "https://developer.download.nvidia.com/compute/cuda/12.5.0/network_installers/cuda_12.5.0_windows_network.exe";
31-
"12.6.3" = "https://developer.download.nvidia.com/compute/cuda/12.6.3/network_installers/cuda_12.6.3_windows_network.exe";
32-
"12.8.0" = "https://developer.download.nvidia.com/compute/cuda/12.8.0/network_installers/cuda_12.8.0_windows_network.exe";
7+
"8.0.44" = "http://developer.nvidia.com/compute/cuda/8.0/Prod/network_installers/cuda_8.0.44_win10_network-exe";
8+
"8.0.61" = "http://developer.nvidia.com/compute/cuda/8.0/Prod2/network_installers/cuda_8.0.61_win10_network-exe";
9+
"9.0.176" = "http://developer.nvidia.com/compute/cuda/9.0/Prod/network_installers/cuda_9.0.176_win10_network-exe";
10+
"9.1.85" = "http://developer.nvidia.com/compute/cuda/9.1/Prod/network_installers/cuda_9.1.85_win10_network";
11+
"9.2.148" = "http://developer.nvidia.com/compute/cuda/9.2/Prod2/network_installers2/cuda_9.2.148_win10_network";
12+
"10.0.130" = "http://developer.nvidia.com/compute/cuda/10.0/Prod/network_installers/cuda_10.0.130_win10_network";
13+
"10.1.105" = "http://developer.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.105_win10_network.exe";
14+
"10.1.168" = "http://developer.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.168_win10_network.exe";
15+
"10.1.243" = "http://developer.download.nvidia.com/compute/cuda/10.1/Prod/network_installers/cuda_10.1.243_win10_network.exe";
16+
"10.2.89" = "http://developer.download.nvidia.com/compute/cuda/10.2/Prod/network_installers/cuda_10.2.89_win10_network.exe";
17+
"11.0.1" = "http://developer.download.nvidia.com/compute/cuda/11.0.1/network_installers/cuda_11.0.1_win10_network.exe";
18+
"11.0.2" = "http://developer.download.nvidia.com/compute/cuda/11.0.2/network_installers/cuda_11.0.2_win10_network.exe";
19+
"11.0.3" = "http://developer.download.nvidia.com/compute/cuda/11.0.3/network_installers/cuda_11.0.3_win10_network.exe";
20+
"11.1.0" = "https://developer.download.nvidia.com/compute/cuda/11.1.0/network_installers/cuda_11.1.0_win10_network.exe";
21+
"11.1.1" = "https://developer.download.nvidia.com/compute/cuda/11.1.1/network_installers/cuda_11.1.1_win10_network.exe";
22+
"11.2.0" = "https://developer.download.nvidia.com/compute/cuda/11.2.0/network_installers/cuda_11.2.0_win10_network.exe";
23+
"11.2.1" = "https://developer.download.nvidia.com/compute/cuda/11.2.1/network_installers/cuda_11.2.1_win10_network.exe";
24+
"11.2.2" = "https://developer.download.nvidia.com/compute/cuda/11.2.2/network_installers/cuda_11.2.2_win10_network.exe";
25+
"11.3.0" = "https://developer.download.nvidia.com/compute/cuda/11.3.0/network_installers/cuda_11.3.0_win10_network.exe";
26+
"11.3.1" = "https://developer.download.nvidia.com/compute/cuda/11.3.1/network_installers/cuda_11.3.1_win10_network.exe";
27+
"11.5.0" = "https://developer.download.nvidia.com/compute/cuda/11.5.0/network_installers/cuda_11.5.0_win10_network.exe";
28+
"11.5.1" = "https://developer.download.nvidia.com/compute/cuda/11.5.1/network_installers/cuda_11.5.1_windows_network.exe";
29+
"11.8.0" = "https://developer.download.nvidia.com/compute/cuda/11.8.0/network_installers/cuda_11.8.0_windows_network.exe";
30+
"12.5.0" = "https://developer.download.nvidia.com/compute/cuda/12.5.0/network_installers/cuda_12.5.0_windows_network.exe";
31+
"12.6.3" = "https://developer.download.nvidia.com/compute/cuda/12.6.3/network_installers/cuda_12.6.3_windows_network.exe";
32+
"12.8.0" = "https://developer.download.nvidia.com/compute/cuda/12.8.0/network_installers/cuda_12.8.0_windows_network.exe";
3333
"12.9.1" = "https://developer.download.nvidia.com/compute/cuda/12.9.1/network_installers/cuda_12.9.1_windows_network.exe";
34+
"13.0.0" = "https://developer.download.nvidia.com/compute/cuda/13.0.0/network_installers/cuda_13.0.0_windows_network.exe";
3435
}
3536

3637
# @todo - change this to be based on _MSC_VER intead, or invert it to be CUDA keyed instead?
3738
$VISUAL_STUDIO_MIN_CUDA = @{
38-
"2019" = "10.1";
39-
"2017" = "10.0"; # Depends on which version of 2017! 9.0 to 10.0 depending on version
40-
"2015" = "8.0"; # might support older, unsure.
39+
"2019" = "10.1";
40+
"2017" = "10.0"; # Depends on which version of 2017! 9.0 to 10.0 depending on version
41+
"2015" = "8.0"; # might support older, unsure.
4142
}
4243

4344
# cuda_runtime.h is in nvcc <= 10.2, but cudart >= 11.0
4445
# @todo - make this easier to vary per CUDA version.
4546
$CUDA_PACKAGES_IN = @(
46-
"nvcc";
47-
"visual_studio_integration";
47+
"nvcc";
48+
"visual_studio_integration";
4849
"cublas";
49-
"cublas_dev";
50+
"cublas_dev";
5051
"curand";
51-
"curand_dev";
52+
"curand_dev";
5253
"nvrtc";
53-
"nvrtc_dev";
54-
"cudart";
54+
"nvrtc_dev";
55+
"cudart";
5556
)
5657

5758

@@ -66,8 +67,8 @@ $CUDA_VERSION_FULL = $env:cuda
6667
# Validate CUDA version, extracting components via regex
6768
$cuda_ver_matched = $CUDA_VERSION_FULL -match "^(?<major>[1-9][0-9]*)\.(?<minor>[0-9]+)\.(?<patch>[0-9]+)$"
6869
if(-not $cuda_ver_matched){
69-
Write-Output "Invalid CUDA version specified, <major>.<minor>.<patch> required. '$CUDA_VERSION_FULL'."
70-
exit 1
70+
Write-Output "Invalid CUDA version specified, <major>.<minor>.<patch> required. '$CUDA_VERSION_FULL'."
71+
exit 1
7172
}
7273
$CUDA_MAJOR=$Matches.major
7374
$CUDA_MINOR=$Matches.minor
@@ -79,16 +80,16 @@ $CUDA_PATCH=$Matches.patch
7980
# Exit if visual studio is too new for the cuda version.
8081
$VISUAL_STUDIO = $env:visual_studio.trim()
8182
if ($VISUAL_STUDIO.length -ge 4) {
82-
$VISUAL_STUDIO_YEAR = $VISUAL_STUDIO.Substring($VISUAL_STUDIO.Length-4)
83-
if ($VISUAL_STUDIO_YEAR.length -eq 4 -and $VISUAL_STUDIO_MIN_CUDA.containsKey($VISUAL_STUDIO_YEAR)){
84-
$MINIMUM_CUDA_VERSION = $VISUAL_STUDIO_MIN_CUDA[$VISUAL_STUDIO_YEAR]
85-
if ([version]$CUDA_VERSION_FULL -lt [version]$MINIMUM_CUDA_VERSION) {
86-
Write-Output "Error: Visual Studio $($VISUAL_STUDIO_YEAR) requires CUDA >= $($MINIMUM_CUDA_VERSION)"
87-
exit 1
88-
}
89-
}
83+
$VISUAL_STUDIO_YEAR = $VISUAL_STUDIO.Substring($VISUAL_STUDIO.Length-4)
84+
if ($VISUAL_STUDIO_YEAR.length -eq 4 -and $VISUAL_STUDIO_MIN_CUDA.containsKey($VISUAL_STUDIO_YEAR)){
85+
$MINIMUM_CUDA_VERSION = $VISUAL_STUDIO_MIN_CUDA[$VISUAL_STUDIO_YEAR]
86+
if ([version]$CUDA_VERSION_FULL -lt [version]$MINIMUM_CUDA_VERSION) {
87+
Write-Output "Error: Visual Studio $($VISUAL_STUDIO_YEAR) requires CUDA >= $($MINIMUM_CUDA_VERSION)"
88+
exit 1
89+
}
90+
}
9091
} else {
91-
Write-Output "Warning: Unknown Visual Studio Version. CUDA version may be insufficient."
92+
Write-Output "Warning: Unknown Visual Studio Version. CUDA version may be insufficient."
9293
}
9394

9495
## ------------------------------------------------
@@ -97,21 +98,20 @@ if ($VISUAL_STUDIO.length -ge 4) {
9798

9899
$CUDA_PACKAGES = ""
99100

100-
# for CUDA >= 11 cudart is a required package.
101-
# if([version]$CUDA_VERSION_FULL -ge [version]"11.0") {
102-
# if(-not $CUDA_PACKAGES_IN -contains "cudart") {
103-
# $CUDA_PACKAGES_IN += 'cudart'
104-
# }
105-
# }
101+
if([version]$CUDA_VERSION_FULL -ge [version]"13.0.0") {
102+
$CUDA_PACKAGES_IN += "crt"
103+
$CUDA_PACKAGES_IN += "nvptxcompiler"
104+
$CUDA_PACKAGES_IN += "nvvm"
105+
}
106106

107107
foreach ($package in $CUDA_PACKAGES_IN) {
108-
# Make sure the correct package name is used for nvcc.
109-
if($package -eq "nvcc" -and [version]$CUDA_VERSION_FULL -lt [version]"9.1"){
110-
$package="compiler"
111-
} elseif($package -eq "compiler" -and [version]$CUDA_VERSION_FULL -ge [version]"9.1") {
112-
$package="nvcc"
113-
}
114-
$CUDA_PACKAGES += " $($package)_$($CUDA_MAJOR).$($CUDA_MINOR)"
108+
# Make sure the correct package name is used for nvcc.
109+
if($package -eq "nvcc" -and [version]$CUDA_VERSION_FULL -lt [version]"9.1"){
110+
$package="compiler"
111+
} elseif($package -eq "compiler" -and [version]$CUDA_VERSION_FULL -ge [version]"9.1") {
112+
$package="nvcc"
113+
}
114+
$CUDA_PACKAGES += " $($package)_$($CUDA_MAJOR).$($CUDA_MINOR)"
115115
}
116116
echo "$($CUDA_PACKAGES)"
117117
## -----------------
@@ -121,12 +121,13 @@ echo "$($CUDA_PACKAGES)"
121121
# Select the download link if known, otherwise have a guess.
122122
$CUDA_REPO_PKG_REMOTE=""
123123
if ($CUDA_KNOWN_URLS.containsKey($CUDA_VERSION_FULL)){
124-
$CUDA_REPO_PKG_REMOTE=$CUDA_KNOWN_URLS[$CUDA_VERSION_FULL]
124+
$CUDA_REPO_PKG_REMOTE=$CUDA_KNOWN_URLS[$CUDA_VERSION_FULL]
125125
} else {
126-
# Guess what the url is given the most recent pattern (at the time of writing, 10.1)
127-
Write-Output "note: URL for CUDA ${$CUDA_VERSION_FULL} not known, estimating."
128-
$CUDA_REPO_PKG_REMOTE="http://developer.download.nvidia.com/compute/cuda/$($CUDA_MAJOR).$($CUDA_MINOR)/Prod/network_installers/cuda_$($CUDA_VERSION_FULL)_win10_network.exe"
126+
# Guess what the url is given the most recent pattern (at the time of writing, 10.1)
127+
Write-Output "note: URL for CUDA ${$CUDA_VERSION_FULL} not known, estimating."
128+
$CUDA_REPO_PKG_REMOTE="http://developer.download.nvidia.com/compute/cuda/$($CUDA_MAJOR).$($CUDA_MINOR)/Prod/network_installers/cuda_$($CUDA_VERSION_FULL)_win10_network.exe"
129129
}
130+
130131
$CUDA_REPO_PKG_LOCAL="cuda_$($CUDA_VERSION_FULL)_win10_network.exe"
131132

132133

@@ -138,10 +139,10 @@ $CUDA_REPO_PKG_LOCAL="cuda_$($CUDA_VERSION_FULL)_win10_network.exe"
138139
Write-Output "Downloading CUDA Network Installer for $($CUDA_VERSION_FULL) from: $($CUDA_REPO_PKG_REMOTE)"
139140
Invoke-WebRequest $CUDA_REPO_PKG_REMOTE -OutFile $CUDA_REPO_PKG_LOCAL | Out-Null
140141
if (Test-Path -Path $CUDA_REPO_PKG_LOCAL){
141-
Write-Output "Downloading Complete"
142+
Write-Output "Downloading Complete"
142143
} else {
143-
Write-Output "Error: Failed to download $($CUDA_REPO_PKG_LOCAL) from $($CUDA_REPO_PKG_REMOTE)"
144-
exit 1
144+
Write-Output "Error: Failed to download $($CUDA_REPO_PKG_LOCAL) from $($CUDA_REPO_PKG_REMOTE)"
145+
exit 1
145146
}
146147

147148
# Invoke silent install of CUDA (via network installer)
@@ -150,8 +151,8 @@ Start-Process -Wait -FilePath .\"$($CUDA_REPO_PKG_LOCAL)" -ArgumentList "-s $($C
150151

151152
# Check the return status of the CUDA installer.
152153
if (!$?) {
153-
Write-Output "Error: CUDA installer reported error. $($LASTEXITCODE)"
154-
exit 1
154+
Write-Output "Error: CUDA installer reported error. $($LASTEXITCODE)"
155+
exit 1
155156
}
156157

157158
# Store the CUDA_PATH in the environment for the current session, to be forwarded in the action.
@@ -170,9 +171,9 @@ Write-Output "CUDA_PATH_VX_Y $($CUDA_PATH_VX_Y)"
170171

171172
# If executing on github actions, emit the appropriate echo statements to update environment variables
172173
if (Test-Path "env:GITHUB_ACTIONS") {
173-
# Set paths for subsequent steps, using $env:CUDA_PATH
174-
echo "Adding CUDA to CUDA_PATH, CUDA_PATH_X_Y and PATH"
175-
echo "CUDA_PATH=$env:CUDA_PATH" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
176-
echo "$env:CUDA_PATH_VX_Y=$env:CUDA_PATH" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
177-
echo "$env:CUDA_PATH/bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
174+
# Set paths for subsequent steps, using $env:CUDA_PATH
175+
echo "Adding CUDA to CUDA_PATH, CUDA_PATH_X_Y and PATH"
176+
echo "CUDA_PATH=$env:CUDA_PATH" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
177+
echo "$env:CUDA_PATH_VX_Y=$env:CUDA_PATH" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
178+
echo "$env:CUDA_PATH/bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
178179
}

src/common_host.cu

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,7 @@
4141

4242
namespace tcnn {
4343

44-
static_assert(
45-
__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2),
46-
"tiny-cuda-nn requires at least CUDA 10.2"
47-
);
44+
static_assert(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2), "tiny-cuda-nn requires at least CUDA 10.2");
4845

4946
std::function<void(LogSeverity, const std::string&)> g_log_callback = [](LogSeverity severity, const std::string& msg) {
5047
switch (severity) {
@@ -214,9 +211,7 @@ int cuda_device() {
214211
return device;
215212
}
216213

217-
void set_cuda_device(int device) {
218-
CUDA_CHECK_THROW(cudaSetDevice(device));
219-
}
214+
void set_cuda_device(int device) { CUDA_CHECK_THROW(cudaSetDevice(device)); }
220215

221216
int cuda_device_count() {
222217
int device_count;
@@ -244,9 +239,7 @@ const cudaDeviceProp& cuda_get_device_properties(int device) {
244239
return cuda_device_properties().at(device);
245240
}
246241

247-
std::string cuda_device_name(int device) {
248-
return cuda_get_device_properties(device).name;
249-
}
242+
std::string cuda_device_name(int device) { return cuda_get_device_properties(device).name; }
250243

251244
uint32_t cuda_compute_capability(int device) {
252245
const auto& props = cuda_get_device_properties(device);
@@ -261,22 +254,20 @@ uint32_t cuda_max_supported_compute_capability() {
261254
return 80;
262255
} else if (cuda_version < 11080) {
263256
return 86;
264-
} else {
257+
} else if (cuda_version < 12080) {
265258
return 90;
259+
} else {
260+
return 120;
266261
}
267262
}
268263

269264
uint32_t cuda_supported_compute_capability(int device) {
270265
return std::min(cuda_compute_capability(device), cuda_max_supported_compute_capability());
271266
}
272267

273-
size_t cuda_max_shmem(int device) {
274-
return cuda_get_device_properties(device).sharedMemPerBlockOptin;
275-
}
268+
size_t cuda_max_shmem(int device) { return cuda_get_device_properties(device).sharedMemPerBlockOptin; }
276269

277-
uint32_t cuda_max_registers(int device) {
278-
return (uint32_t)cuda_get_device_properties(device).regsPerBlock;
279-
}
270+
uint32_t cuda_max_registers(int device) { return (uint32_t)cuda_get_device_properties(device).regsPerBlock; }
280271

281272
size_t cuda_memory_granularity(int device) {
282273
size_t granularity;
@@ -358,4 +349,4 @@ template <> std::string type_to_string<double>() { return "double"; }
358349
template <> std::string type_to_string<float>() { return "float"; }
359350
template <> std::string type_to_string<__half>() { return "__half"; }
360351

361-
}
352+
} // namespace tcnn

0 commit comments

Comments
 (0)