Commit f1358b1

Prefer CUDA over DirectML when both available (#28)
Fix get_installed_torch_platform() to prefer torch.cuda when both torch_directml and torch.cuda are available (common in Windows setups that keep torch-directml installed after switching to ROCm/CUDA builds).

Co-authored-by: Godnight1006
1 parent: f0caa68 · commit: f1358b1
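
As a quick illustration (a hypothetical usage sketch, assuming torchruntime is installed; the expected output is inferred from the tests below), detection now resolves to CUDA on a machine that has both backends:

from torchruntime.utils.torch_device_utils import get_installed_torch_platform

# With both a CUDA-enabled torch build and torch_directml importable,
# the platform name now resolves to "cuda" rather than "directml".
torch_platform_name, torch_platform = get_installed_torch_platform()
print(torch_platform_name)  # "cuda"; torch_platform is the torch.cuda module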

File tree

2 files changed: +62, -3 lines


tests/test_torch_device_utils.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import sys
import types


def _make_fake_torch(*, cuda_available: bool):
    torch = types.ModuleType("torch")
    torch.__path__ = []  # allow importing torch.backends

    class _FakeCUDA:
        @staticmethod
        def is_available():
            return cuda_available

    torch.cuda = _FakeCUDA()
    torch.cpu = object()

    backends = types.ModuleType("torch.backends")
    backends.__path__ = []
    torch.backends = backends

    return torch, backends


def test_get_installed_torch_platform_prefers_cuda_over_directml(monkeypatch):
    # Regression test for Windows environments where users have both a working CUDA/ROCm torch
    # backend AND torch-directml installed. In that scenario we should prefer torch.cuda over
    # DirectML to avoid mis-detecting the active backend.
    fake_torch, fake_backends = _make_fake_torch(cuda_available=True)

    fake_torch_directml = types.ModuleType("torch_directml")
    fake_torch_directml.is_available = lambda: True

    monkeypatch.setitem(sys.modules, "torch", fake_torch)
    monkeypatch.setitem(sys.modules, "torch.backends", fake_backends)
    monkeypatch.setitem(sys.modules, "torch_directml", fake_torch_directml)

    from torchruntime.utils.torch_device_utils import get_installed_torch_platform

    torch_platform_name, torch_platform = get_installed_torch_platform()
    assert torch_platform_name == "cuda"
    assert torch_platform is fake_torch.cuda


def test_get_installed_torch_platform_uses_directml_when_cuda_unavailable(monkeypatch):
    fake_torch, fake_backends = _make_fake_torch(cuda_available=False)

    fake_torch_directml = types.ModuleType("torch_directml")
    fake_torch_directml.is_available = lambda: True

    monkeypatch.setitem(sys.modules, "torch", fake_torch)
    monkeypatch.setitem(sys.modules, "torch.backends", fake_backends)
    monkeypatch.setitem(sys.modules, "torch_directml", fake_torch_directml)

    from torchruntime.utils.torch_device_utils import get_installed_torch_platform

    torch_platform_name, _ = get_installed_torch_platform()
    assert torch_platform_name == "directml"
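
A note on the technique used above: the tests never import the real torch. _make_fake_torch builds stand-in modules (torch.__path__ = [] marks the fake as a package, so import torch.backends is legal), and monkeypatch.setitem plants them in sys.modules so the import inside get_installed_torch_platform resolves to the fakes; pytest restores the original entries after each test. A standalone sketch of that pattern (illustrative, not part of this repo):

import sys
import types

# Register a stand-in module; any later `import torch_directml` resolves
# from sys.modules first, so no real torch-directml install is needed.
fake = types.ModuleType("torch_directml")
fake.is_available = lambda: True
sys.modules["torch_directml"] = fake

import torch_directml

assert torch_directml.is_available()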

torchruntime/utils/torch_device_utils.py

Lines changed: 5 additions & 3 deletions
@@ -42,13 +42,15 @@ def get_installed_torch_platform():
     import torch.backends
     from platform import system as os_name

-    if _is_directml_platform_available():
-        return DIRECTML, torch.directml
-
     if torch.cuda.is_available():
         return CUDA, torch.cuda
     if hasattr(torch, XPU) and torch.xpu.is_available():
         return XPU, torch.xpu
+
+    # DirectML is a useful fallback on Windows, but users can have torch-directml installed
+    # alongside a working CUDA/ROCm torch build. Prefer the native torch backend when available.
+    if _is_directml_platform_available():
+        return DIRECTML, torch.directml
     if os_name() == "Darwin":
         if hasattr(torch, MPS):
             return MPS, torch.mps
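
One detail worth spelling out (an inference from the commit message, not stated in the diff): ROCm builds of PyTorch expose their HIP backend through the torch.cuda namespace, so the single torch.cuda.is_available() check covers both CUDA and ROCm builds before the DirectML fallback is consulted. A quick probe, assuming any recent torch build:

import torch

# On ROCm wheels torch.version.hip is set; on CUDA wheels torch.version.cuda is set.
# Either way torch.cuda.is_available() reports the native backend, which now wins
# over DirectML in the detection order shown above.
print(torch.cuda.is_available(), torch.version.hip, torch.version.cuda)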
