Skip to content

Commit 40e7433

Browse files
pytorchbot and atalman authored
Add warning about removed sm50 and sm60 arches (pytorch#158478)
Add warning about removed sm50 and sm60 arches (pytorch#158301) Related to pytorch#157517 Detect when users are executing torch build with cuda 12.8/12.9 and running on Maxwell or Pascal architectures. We would like to include reference to the issue: pytorch#157517 as well as ask people to install CUDA 12.6 builds if they are running on sm50 or sm60 architectures. Test: ``` >>> torch.cuda.get_arch_list() ['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120'] >>> torch.cuda.init() /home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:263: UserWarning: Found <GPU Name> which is of cuda capability 5.0. PyTorch no longer supports this GPU because it is too old. The minimum cuda capability supported by this library is 7.0. warnings.warn( /home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:268: UserWarning: Support for Maxwell and Pascal architectures is removed for CUDA 12.8+ builds. Please see pytorch#157517 Please install CUDA 12.6 builds if you require Maxwell or Pascal support. ``` Pull Request resolved: pytorch#158301 Approved by: https://github.com/nWEIdia, https://github.com/albanD (cherry picked from commit fb731fe) Co-authored-by: atalman <[email protected]>
1 parent 779e6f3 commit 40e7433

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

torch/cuda/__init__.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -244,21 +244,25 @@ def _extract_arch_version(arch_string: str):
244244

245245

246246
def _check_capability():
247-
incorrect_binary_warn = """
248-
Found GPU%d %s which requires CUDA_VERSION >= %d to
249-
work properly, but your PyTorch was compiled
250-
with CUDA_VERSION %d. Please install the correct PyTorch binary
251-
using instructions from https://pytorch.org
252-
""" # noqa: F841
253-
254-
old_gpu_warn = """
247+
incompatible_gpu_warn = """
255248
Found GPU%d %s which is of cuda capability %d.%d.
256-
PyTorch no longer supports this GPU because it is too old.
257-
The minimum cuda capability supported by this library is %d.%d.
249+
Minimum and Maximum cuda capability supported by this version of PyTorch is
250+
(%d.%d) - (%d.%d)
258251
"""
252+
matched_cuda_warn = """
253+
Please install PyTorch with a following CUDA
254+
configurations: {} following instructions at
255+
https://pytorch.org/get-started/locally/
256+
"""
257+
258+
# Binary CUDA_ARCHES SUPPORTED by PyTorch
259+
CUDA_ARCHES_SUPPORTED = {
260+
"12.6": {"min": 50, "max": 90},
261+
"12.8": {"min": 70, "max": 120},
262+
"12.9": {"min": 70, "max": 120},
263+
}
259264

260265
if torch.version.cuda is not None: # on ROCm we don't want this check
261-
CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841
262266
for d in range(device_count()):
263267
capability = get_device_capability(d)
264268
major = capability[0]
@@ -267,13 +271,35 @@ def _check_capability():
267271
current_arch = major * 10 + minor
268272
min_arch = min(
269273
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
270-
default=35,
274+
default=50,
271275
)
272-
if current_arch < min_arch:
276+
max_arch = max(
277+
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
278+
default=50,
279+
)
280+
if current_arch < min_arch or current_arch > max_arch:
273281
warnings.warn(
274-
old_gpu_warn
275-
% (d, name, major, minor, min_arch // 10, min_arch % 10)
282+
incompatible_gpu_warn
283+
% (
284+
d,
285+
name,
286+
major,
287+
minor,
288+
min_arch // 10,
289+
min_arch % 10,
290+
max_arch // 10,
291+
max_arch % 10,
292+
)
276293
)
294+
matched_arches = ""
295+
for arch, arch_info in CUDA_ARCHES_SUPPORTED.items():
296+
if (
297+
current_arch >= arch_info["min"]
298+
and current_arch <= arch_info["max"]
299+
):
300+
matched_arches += f" {arch}"
301+
if matched_arches != "":
302+
warnings.warn(matched_cuda_warn.format(matched_arches))
277303

278304

279305
def _check_cubins():

0 commit comments

Comments (0)