Commit 003af4b

godnight10061 and Godnight1006 authored
Add Triton installation for Windows CUDA, Linux ROCm/XPU (#31)
* feat: install Triton on more platforms
* fix: declare packaging dependency
* feat: add torch.compile triton self-test

---------

Co-authored-by: Godnight1006 <[email protected]>
1 parent b0e3294 commit 003af4b

7 files changed: +147 −11 lines changed

API.md

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,15 @@ Or you can use the library:
 torchruntime.install(["torch", "torchvision<0.20"])
 ```
 
+On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).
+
+## Test torch
+Run:
+`python -m torchruntime test`
+
+To specifically verify `torch.compile` / Triton:
+`python -m torchruntime test compile`
+
 ## Get device info
 You can use the device database built into `torchruntime` for your projects:
 ```py
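As a quick reference, here is a minimal sketch of the flow these API.md changes describe, assuming the `torchruntime` package is installed (the `torchruntime.install()` call appears in the diff context above, and `test compile` is the subcommand added by this commit):

```py
import subprocess
import sys

import torchruntime

# Install torch + torchvision; per the added note, Windows CUDA, Linux ROCm 6.x,
# and Linux XPU setups also get the matching Triton package (triton-windows,
# pytorch-triton-rocm, or pytorch-triton-xpu) so that torch.compile can work.
torchruntime.install(["torch", "torchvision"])

# Then run only the torch.compile / Triton self-test added by this commit.
subprocess.run([sys.executable, "-m", "torchruntime", "test", "compile"], check=True)
```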

README.md

Lines changed: 3 additions & 1 deletion
@@ -30,6 +30,8 @@ Supports Windows, Linux, and Mac.
 
 This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.
 
+On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).
+
 **Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv`
 
 ### Step 2. Configure torch
@@ -42,7 +44,7 @@ torchruntime.configure()
 ```
 
 ### (Optional) Step 3. Test torch
-Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly.
+Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly (including a `torch.compile` / Triton check on CUDA/XPU systems). You can also run `python -m torchruntime test compile` to run only the compile check.
 
 ## Customizing packages
 By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform.

setup.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 import setuptools
 
 setuptools.setup(
-    install_requires=[],
+    install_requires=["packaging"],
 )

tests/test_installer.py

Lines changed: 37 additions & 3 deletions
@@ -16,27 +16,57 @@ def test_cpu_platform():
     assert result == [packages]
 
 
-def test_cuda_platform():
+def test_cuda_platform(monkeypatch):
+    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
     packages = ["torch", "torchvision"]
     result = get_install_commands("cu112", packages)
     expected_url = "https://download.pytorch.org/whl/cu112"
     assert result == [packages + ["--index-url", expected_url]]
 
 
-def test_cuda_nightly_platform():
+def test_cuda_platform_windows_installs_triton(monkeypatch):
+    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
+    packages = ["torch", "torchvision"]
+    result = get_install_commands("cu112", packages)
+    expected_url = "https://download.pytorch.org/whl/cu112"
+    assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]
+
+
+def test_cuda_nightly_platform(monkeypatch):
+    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
     packages = ["torch", "torchvision"]
     result = get_install_commands("nightly/cu112", packages)
     expected_url = "https://download.pytorch.org/whl/nightly/cu112"
     assert result == [packages + ["--index-url", expected_url]]
 
 
+def test_cuda_nightly_platform_windows_installs_triton(monkeypatch):
+    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
+    packages = ["torch", "torchvision"]
+    result = get_install_commands("nightly/cu112", packages)
+    expected_url = "https://download.pytorch.org/whl/nightly/cu112"
+    assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]
+
+
 def test_rocm_platform():
     packages = ["torch", "torchvision"]
     result = get_install_commands("rocm4.2", packages)
     expected_url = "https://download.pytorch.org/whl/rocm4.2"
     assert result == [packages + ["--index-url", expected_url]]
 
 
+def test_rocm_platform_linux_installs_triton(monkeypatch):
+    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
+    packages = ["torch", "torchvision"]
+    result = get_install_commands("rocm6.2", packages)
+    expected_url = "https://download.pytorch.org/whl/rocm6.2"
+    triton_index_url = "https://download.pytorch.org/whl"
+    assert result == [
+        packages + ["--index-url", expected_url],
+        ["pytorch-triton-rocm", "--index-url", triton_index_url],
+    ]
+
+
 def test_xpu_platform_windows_with_torch_only(monkeypatch):
     monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
     packages = ["torch"]
@@ -60,7 +90,11 @@ def test_xpu_platform_linux(monkeypatch):
     packages = ["torch", "torchvision"]
     result = get_install_commands("xpu", packages)
     expected_url = "https://download.pytorch.org/whl/test/xpu"
-    assert result == [packages + ["--index-url", expected_url]]
+    triton_index_url = "https://download.pytorch.org/whl"
+    assert result == [
+        packages + ["--index-url", expected_url],
+        ["pytorch-triton-xpu", "--index-url", triton_index_url],
+    ]
 
 
 def test_directml_platform():

torchruntime/__main__.py

Lines changed: 3 additions & 2 deletions
@@ -10,7 +10,7 @@ def print_usage(entry_command: str):
 
 Commands:
   install            Install PyTorch packages
-  test [subcommand]  Run tests (subcommands: all, devices, math, functions)
+  test [subcommand]  Run tests (subcommands: all, import, devices, compile, math, functions)
   --help             Show this help message
 
 Examples:
@@ -20,10 +20,11 @@ def print_usage(entry_command: str):
   {entry_command} install --uv torch>=2.0.0 torchaudio
   {entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0
 
-  {entry_command} test            # Runs all tests (import, devices, math, functions)
+  {entry_command} test            # Runs all tests (import, devices, compile, math, functions)
   {entry_command} test all        # Same as above
   {entry_command} test import     # Test only import
   {entry_command} test devices    # Test only devices
+  {entry_command} test compile    # Test torch.compile (Triton)
   {entry_command} test math       # Test only math
   {entry_command} test functions  # Test only functions
 

torchruntime/installer.py

Lines changed: 19 additions & 2 deletions
@@ -12,6 +12,7 @@
 PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
 CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
 ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
+ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P<major>\d+)\.(?P<minor>\d+)$")
 
 
 def get_install_commands(torch_platform, packages):
@@ -43,6 +44,9 @@ def get_install_commands(torch_platform, packages):
     - For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds.
     - For "directml", the "torch-directml" package is returned as part of the installation commands.
     - For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands.
+    - For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels).
+    - For Linux ROCm 6.x, the function also installs "pytorch-triton-rocm".
+    - For Linux XPU, the function also installs "pytorch-triton-xpu".
     """
     if not packages:
         packages = ["torch", "torchaudio", "torchvision"]
@@ -52,7 +56,17 @@ def get_install_commands(torch_platform, packages):
 
     if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform):
         index_url = f"https://download.pytorch.org/whl/{torch_platform}"
-        return [packages + ["--index-url", index_url]]
+        cmds = [packages + ["--index-url", index_url]]
+
+        if os_name == "Windows" and CUDA_REGEX.match(torch_platform):
+            cmds.append(["triton-windows"])
+
+        if os_name == "Linux" and ROCM_REGEX.match(torch_platform):
+            match = ROCM_VERSION_REGEX.match(torch_platform)
+            if match and int(match.group("major")) >= 6:
+                cmds.append(["pytorch-triton-rocm", "--index-url", "https://download.pytorch.org/whl"])
+
+        return cmds
 
     if torch_platform == "xpu":
         if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages):
@@ -65,7 +79,10 @@ def get_install_commands(torch_platform, packages):
         else:
            index_url = f"https://download.pytorch.org/whl/test/{torch_platform}"
 
-        return [packages + ["--index-url", index_url]]
+        cmds = [packages + ["--index-url", index_url]]
+        if os_name == "Linux":
+            cmds.append(["pytorch-triton-xpu", "--index-url", "https://download.pytorch.org/whl"])
+        return cmds
 
     if torch_platform == "directml":
         return [["torch-directml"], packages]

torchruntime/utils/torch_test/__init__.py

Lines changed: 75 additions & 2 deletions
@@ -1,6 +1,8 @@
+import importlib.util
+import platform
 import time
 
-from ..torch_device_utils import get_installed_torch_platform, get_device_count, get_device_name, get_device
+from ..torch_device_utils import get_device, get_device_count, get_device_name, get_installed_torch_platform
 
 
 def test(subcommand):
@@ -16,7 +18,7 @@ def test(subcommand):
 
 
 def test_all():
-    for fn in (test_import, test_devices, test_math, test_functions):
+    for fn in (test_import, test_devices, test_compile, test_math, test_functions):
         fn()
         print("")
 
@@ -101,3 +103,74 @@ def test_functions():
     t.run_all_tests()
 
     print("--- / FUNCTIONAL TEST ---")
+
+
+def test_compile():
+    print("--- COMPILE TEST ---")
+
+    try:
+        import torch
+    except ImportError:
+        print("torch.compile: SKIPPED (torch not installed)")
+        print("--- / COMPILE TEST ---")
+        return
+
+    if not hasattr(torch, "compile"):
+        print("torch.compile: SKIPPED (requires torch>=2.0)")
+        print("--- / COMPILE TEST ---")
+        return
+
+    torch_platform_name, _ = get_installed_torch_platform()
+    if torch_platform_name not in ("cuda", "xpu"):
+        print(f"torch.compile: SKIPPED (unsupported backend: {torch_platform_name})")
+        print("--- / COMPILE TEST ---")
+        return
+
+    if importlib.util.find_spec("triton") is None:
+        print("triton: NOT INSTALLED")
+    else:
+        print("triton: installed")
+
+    device = get_device(0)
+    print("On torch device:", device)
+
+    def f(x):
+        return x * 2 + 1
+
+    try:
+        compiled_f = torch.compile(f)
+        x = torch.randn((1024,), device=device)
+        y = compiled_f(x)
+        expected = f(x)
+        if not torch.allclose(y, expected):
+            print("torch.compile: FAILED (output mismatch)")
+        else:
+            if torch_platform_name == "cuda":
+                torch.cuda.synchronize()
+            if torch_platform_name == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"):
+                torch.xpu.synchronize()
+            print("torch.compile: PASSED")
+    except Exception as e:
+        print(f"torch.compile: FAILED ({type(e).__name__}: {e})")
+
+        hint = None
+        os_name = platform.system()
+        if torch_platform_name == "cuda" and os_name == "Windows":
+            hint = "pip install triton-windows (or: python -m torchruntime install)"
+        elif torch_platform_name == "cuda" and os_name == "Linux":
+            if getattr(torch.version, "hip", None):
+                hint = (
+                    "pip install pytorch-triton-rocm --index-url https://download.pytorch.org/whl "
+                    "(or: python -m torchruntime install)"
+                )
+        elif torch_platform_name == "xpu" and os_name == "Linux":
+            hint = (
+                "pip install pytorch-triton-xpu --index-url https://download.pytorch.org/whl "
+                "(or: python -m torchruntime install)"
+            )
+
+        if hint:
+            print("If this failed due to Triton, try:")
+            print("  ", hint)
+
+    print("--- / COMPILE TEST ---")
