Skip to content

Commit fce59bd

Browse files
godnight10061 and Godnight1006 authored
Fix cu128 index selection for pinned Torch versions (#30)
* Fix CUDA platform selection for pinned torch versions

  When a system would normally use cu128, but the requested torch/torchvision/torchaudio versions are capped below the first cu128 wheels, automatically use cu124 instead. Adds a regression test for issue #16.

* Move CUDA package-based demotion to platform detection

* Inline cu128 demotion check

---------

Co-authored-by: Godnight1006 <[email protected]>
1 parent 2e96284 commit fce59bd

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed

tests/test_platform_detection.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,23 @@ def test_nvidia_gpu_linux(monkeypatch):
121121
assert get_torch_platform(gpu_infos) == expected
122122

123123

124+
def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch):
    """Pins that rule out every cu128-era release must demote a cu128-capable system to cu124."""
    # Simulate a Windows/amd64 machine on Python 3.11 with an NVIDIA GPU of arch 8.6
    patched_attributes = {
        "torchruntime.platform_detection.os_name": "Windows",
        "torchruntime.platform_detection.arch": "amd64",
        "torchruntime.platform_detection.py_version": (3, 11),
        "torchruntime.platform_detection.get_nvidia_arch": lambda device_names: 8.6,
    }
    for target, replacement in patched_attributes.items():
        monkeypatch.setattr(target, replacement)

    gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)]

    # Without any requested packages this system resolves to cu128
    assert get_torch_platform(gpu_infos) == "cu128"

    # (requirement string, expected platform)
    pinned_cases = [
        ("torch==2.6.0", "cu124"),
        ("torch<2.7.0", "cu124"),
        ("torch<=2.7.0", "cu128"),
        ("torch!=2.7.0", "cu128"),
        ("torch>=2.7.0,!=2.7.0,!=2.7.1,<2.8.0", "cu128"),
        ("torchvision==0.21.0", "cu124"),
    ]
    for requirement, expected_platform in pinned_cases:
        assert get_torch_platform(gpu_infos, packages=[requirement]) == expected_platform
139+
140+
124141
def test_nvidia_gpu_mac(monkeypatch):
125142
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
126143
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")

torchruntime/installer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def install(packages=[], use_uv=False):
9898
"""
9999

100100
gpu_infos = get_gpus()
101-
torch_platform = get_torch_platform(gpu_infos)
101+
torch_platform = get_torch_platform(gpu_infos, packages=packages)
102102
cmds = get_install_commands(torch_platform, packages)
103103
cmds = get_pip_commands(cmds, use_uv=use_uv)
104104
run_commands(cmds)

torchruntime/platform_detection.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,63 @@
1-
import re
21
import sys
32
import platform
43

4+
from packaging.requirements import Requirement
5+
from packaging.version import Version
6+
57
from .gpu_db import get_nvidia_arch, get_amd_gfx_info
68
from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK
79

810
os_name = platform.system()
911
arch = platform.machine().lower()
1012
py_version = sys.version_info
1113

14+
# First release of each package that ships cu128 wheels. Requirements capped below
# these versions cannot be installed from the cu128 index, so cu124 is used instead.
_CUDA_12_8_MIN_VERSIONS = {
    "torch": Version("2.7.0"),
    "torchaudio": Version("2.7.0"),
    "torchvision": Version("0.22.0"),
}


def _nearby_releases(version):
    """Return `version` followed by its next micro, next minor and next major release."""
    return [
        version,
        Version(f"{version.major}.{version.minor}.{version.micro + 1}"),
        Version(f"{version.major}.{version.minor + 1}.0"),
        Version(f"{version.major + 1}.0.0"),
    ]


def _cuda_12_8_probe_versions(threshold, specifier):
    """
    Build the candidate versions used to test whether `specifier` admits `threshold` or newer.

    Args:
        threshold (Version): First release that ships cu128 wheels.
        specifier (SpecifierSet): Version clauses of the requirement being inspected.

    Returns:
        list of Version: Candidates that are all >= `threshold`.
    """
    probes = _nearby_releases(threshold)
    probes.append(Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 2}"))
    probes.append(Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 3}"))

    # Also probe around the versions named by the requirement itself. Without this, a pin
    # far above the threshold (e.g. "torch==2.9.0") matches none of the fixed probes and
    # would wrongly be treated as "capped below the first cu128 release".
    for clause in specifier:
        pinned_text = clause.version
        if pinned_text.endswith(".*"):
            pinned_text = pinned_text[:-2]  # "==2.9.*" -> probe around 2.9

        try:
            pinned = Version(pinned_text)
        except ValueError:
            # "===" clauses may carry strings that are not valid versions
            continue

        probes.extend(_nearby_releases(pinned))

    # Only versions at or above the threshold prove that a cu128 wheel can be selected
    return [probe for probe in probes if probe >= threshold]


def _packages_require_cuda_12_4(packages):
    """
    Check whether any requested package is capped below its first cu128-capable release.

    Args:
        packages (list of str): Requirement strings, e.g. "torch==2.6.0". May be empty or None.

    Returns:
        bool: True if at least one requirement for a package listed in `_CUDA_12_8_MIN_VERSIONS`
        admits none of the probed versions at or above that package's threshold. False otherwise,
        including when `packages` is empty or no entry names a known package with a version clause.
    """
    if not packages:
        return False

    for package in packages:
        try:
            requirement = Requirement(package)
        except Exception:
            # Best-effort: entries that cannot be parsed as requirements are ignored
            continue

        name = requirement.name.lower().replace("_", "-")
        threshold = _CUDA_12_8_MIN_VERSIONS.get(name)
        if not threshold or not requirement.specifier:
            continue

        probes = _cuda_12_8_probe_versions(threshold, requirement.specifier)
        allows_threshold_or_higher = any(
            requirement.specifier.contains(str(version), prereleases=True) for version in probes
        )
        if not allows_threshold_or_higher:
            return True

    return False
1252

13-
def get_torch_platform(gpu_infos):
53+
54+
def get_torch_platform(gpu_infos, packages=[]):
1455
"""
1556
Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.
1657
1758
Args:
1859
gpu_infos (list of `torchruntime.device_db.GPU` instances)
60+
packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings.
1961
2062
Returns:
2163
str: A string representing the platform to use. Possible values:
@@ -53,12 +95,12 @@ def get_torch_platform(gpu_infos):
5395
integrated_devices.append(device)
5496

5597
if discrete_devices:
56-
return _get_platform_for_discrete(discrete_devices)
98+
return _get_platform_for_discrete(discrete_devices, packages=packages)
5799

58100
return _get_platform_for_integrated(integrated_devices)
59101

60102

61-
def _get_platform_for_discrete(gpu_infos):
103+
def _get_platform_for_discrete(gpu_infos, packages=None):
62104
vendor_ids = set(gpu.vendor_id for gpu in gpu_infos)
63105

64106
if len(vendor_ids) > 1:
@@ -126,6 +168,9 @@ def _get_platform_for_discrete(gpu_infos):
126168
if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9):
127169
return "cu124"
128170

171+
if _packages_require_cuda_12_4(packages):
172+
return "cu124"
173+
129174
return "cu128"
130175
elif os_name == "Darwin":
131176
raise NotImplementedError(

0 commit comments

Comments
 (0)