Commit 280c57e

fix: working README.md example, support nvfuser for torch==2.8 (#2525)
Parent: ac24932

4 files changed (+62 −30 lines)

README.md

Lines changed: 29 additions & 12 deletions
````diff
@@ -77,33 +77,50 @@ For **performance experts**, Thunder is the most ergonomic framework for underst
 Install Thunder via pip ([more options](https://lightning.ai/docs/thunder/latest/fundamentals/installation.html)):
 
 ```bash
-pip install torch==2.6.0 torchvision==0.21 nvfuser-cu124-torch26
-
 pip install lightning-thunder
+
+pip install -U torch torchvision
+pip install nvfuser-cu128-torch28 nvidia-cudnn-frontend # if NVIDIA GPU is present
 ```
 
 <details>
-<summary>Advanced install options</summary>
+<summary>For older versions of <code>torch</code></summary>
+
+<code>torch==2.7</code> + CUDA 12.8
+
+```bash
+pip install lightning-thunder
 
-### Blackwell support
+pip install torch==2.7.0 torchvision==0.22
+pip install nvfuser-cu128-torch27 nvidia-cudnn-frontend # if NVIDIA GPU is present
+```
 
-For Blackwell you'll need CUDA 12.8
+<code>torch==2.6</code> + CUDA 12.6
 
 ```bash
-pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
-pip install --pre nvfuser-cu128 --extra-index-url https://pypi.nvidia.com
+pip install lightning-thunder
 
+pip install torch==2.6.0 torchvision==0.21
+pip install nvfuser-cu126-torch26 nvidia-cudnn-frontend # if NVIDIA GPU is present
+```
+
+<code>torch==2.5</code> + CUDA 12.4
+
+```bash
 pip install lightning-thunder
+
+pip install torch==2.5.0 torchvision==0.20
+pip install nvfuser-cu124-torch25 nvidia-cudnn-frontend # if NVIDIA GPU is present
 ```
 
-### Install additional executors
+</details>
 
-These are optional, feel free to mix and match
+<details>
+<summary>Advanced install options</summary>
 
-```bash
-# cuDNN SDPA
-pip install nvidia-cudnn-frontend
+### Install optional executors
 
+```bash
 # Float8 support (this will compile from source, be patient)
 pip install "transformer_engine[pytorch]"
 ```
````
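The updated instructions can be sanity-checked with a short script. The sketch below is not part of the commit; it assumes an install done with the commands above, uses Thunder's public `thunder.jit` entry point and `thunder.last_traces`, and simply runs on CPU when no NVIDIA GPU (and hence no nvFuser) is present.

```python
# Minimal smoke test for the install above (a sketch, not part of this commit).
import torch
import thunder


def f(x, y):
    return torch.nn.functional.relu(x @ y)


jf = thunder.jit(f)  # compile the function with Thunder

# Use the GPU path (nvFuser/cuDNN executors) only if CUDA is actually available.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 4, device=device)
y = torch.randn(4, 4, device=device)

print(jf(x, y))
print(thunder.last_traces(jf)[-1])  # inspect the final execution trace
```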

thunder/executors/transformer_engineex.py

Lines changed: 10 additions & 7 deletions
```diff
@@ -8,18 +8,21 @@
 from thunder.extend import StatefulExecutor
 from thunder.core.trace import TraceCtx
 
+import torch
+
 __all__ = ["transformer_engine_ex", "TransformerEngineTransform", "_te_activation_checkpointing_transform"]
 
 transformer_engine_ex: None | StatefulExecutor = None
 TransformerEngineTransform: None | Transform = None
 _te_activation_checkpointing_transform: None | Callable[[TraceCtx], TraceCtx] = None
 
-if package_available("transformer_engine"):
-    import thunder.executors.transformer_engineex_impl as impl
+if torch.cuda.is_available():
+    if package_available("transformer_engine"):
+        import thunder.executors.transformer_engineex_impl as impl
 
-    transformer_engine_ex = impl.transformer_engine_ex
-    TransformerEngineTransform = impl.TransformerEngineTransform
-    _te_activation_checkpointing_transform = impl._te_activation_checkpointing_transform
+        transformer_engine_ex = impl.transformer_engine_ex
+        TransformerEngineTransform = impl.TransformerEngineTransform
+        _te_activation_checkpointing_transform = impl._te_activation_checkpointing_transform
 
-else:
-    warnings.warn("transformer_engine module not found!")
+    else:
+        warnings.warn("transformer_engine module not found!")
```
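The change above gates the optional Transformer Engine import behind a CUDA check, so CPU-only environments no longer emit a spurious "module not found" warning at import time. A standalone sketch of the same pattern, with generic helper names that are not repo code:

```python
# Guarded optional-import pattern, in isolation (illustrative helper, not repo code).
import importlib
import importlib.util
import warnings

import torch


def load_optional_gpu_extra(module_name: str):
    """Import a GPU-only optional dependency, or return None when it does not apply."""
    if not torch.cuda.is_available():
        # CPU-only install: the extra is irrelevant, so skip silently instead of warning.
        return None
    if importlib.util.find_spec(module_name) is None:
        warnings.warn(f"{module_name} module not found!")
        return None
    return importlib.import_module(module_name)


te = load_optional_gpu_extra("transformer_engine")
print("transformer_engine available:", te is not None)
```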
thunder/executors/triton_crossentropy.py

Lines changed: 12 additions & 8 deletions

```diff
@@ -1,16 +1,20 @@
 from thunder.executors import triton_utils
 from thunder.extend import OperatorExecutor
 
+import torch
+
 triton_version: None | str = triton_utils.triton_version()
 
 triton_ex: None | OperatorExecutor = None
-if triton_version is not None:
-    try:
-        from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex
 
-        triton_ex = impl_ex
-    except Exception:
-        import warnings
+if torch.cuda.is_available():
+    if triton_version is not None:
+        try:
+            from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex
+
+            triton_ex = impl_ex
+        except Exception:
+            import warnings
 
-        warnings.warn("triton is present but cannot be initialized")
-        triton_version = None
+            warnings.warn("triton is present but cannot be initialized")
+            triton_version = None
```
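The triton cross-entropy executor gets the same CUDA gate, kept inside the existing try/except because triton can be installed yet still fail to initialize (for example, a driver or toolchain mismatch). A self-contained sketch of that fallback, using a plain boolean flag instead of the executor object:

```python
# Try/except fallback for an optional compiler backend (a sketch, not repo code).
import warnings

import torch

triton_usable = False
if torch.cuda.is_available():
    try:
        import triton  # noqa: F401  # may raise even when installed, e.g. on a broken toolchain

        triton_usable = True
    except Exception:
        warnings.warn("triton is present but cannot be initialized")

print("triton usable:", triton_usable)
```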

thunder/recipes/base.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -13,6 +13,7 @@ def get_nvfuser_package_hint() -> str:
         "2.5": "nvfuser-cu124-torch25",
         "2.6": "nvfuser-cu126-torch26",
         "2.7": "nvfuser-cu128-torch27",
+        "2.8": "nvfuser-cu128-torch28",
     }
 
     torch_key = ".".join(torch_version.split(".")[:2])
@@ -73,8 +74,17 @@ def __init__(
         plugins=None,
     ):
         super().__init__(interpreter=interpreter, plugins=plugins)
-        self.executor_names = ["cudnn", "sdpa", "torchcompile_xentropy"]
         self.fuser = fuser
+        self.executor_names = []
+
+        if torch.cuda.is_available():
+            self.executor_names = ["cudnn", "sdpa"]
+            if self.fuser == "nvfuser":
+                self.executor_names.append("torchcompile_xentropy")
+        else:
+            print("GPU not found, nvFuser not available. Setting fusing executor to torch.compile")
+            self.fuser = "torch.compile"
+
         self.setup_fuser()
         self.show_progress = show_progress
 
@@ -114,8 +124,6 @@ def setup_fuser(self) -> None:
             if "nvfuser" not in self.executor_names:
                 self.executor_names.append("nvfuser")
         elif self.fuser == "torch.compile":
-            if "torchcompile_xentropy" in self.executor_names:
-                self.executor_names.remove("torchcompile_xentropy")
             if "torchcompile" not in self.executor_names:
                 self.executor_names.append("torchcompile")
         else:
```
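The first hunk extends the torch-version-to-wheel hint used when nvFuser is missing. That lookup can be reproduced standalone; the dictionary name below is made up for illustration, and `torch.__version__` stands in for the `torch_version` value the recipe inspects:

```python
# Standalone sketch of the nvfuser wheel hint extended by this commit.
import torch

# Hypothetical name; the diff only shows the mapping entries and the key derivation.
NVFUSER_WHEEL_HINTS = {
    "2.5": "nvfuser-cu124-torch25",
    "2.6": "nvfuser-cu126-torch26",
    "2.7": "nvfuser-cu128-torch27",
    "2.8": "nvfuser-cu128-torch28",  # new in this commit
}

# Same key derivation as the recipe: keep only the major.minor part of the version.
torch_key = ".".join(torch.__version__.split(".")[:2])  # e.g. "2.8.0+cu128" -> "2.8"
print(NVFUSER_WHEEL_HINTS.get(torch_key, "no known nvfuser wheel for this torch version"))
```

The other two hunks make the recipe usable on CPU-only machines: the GPU executors (`cudnn`, `sdpa`, plus `torchcompile_xentropy` under the `nvfuser` fuser) are registered only when CUDA is available; otherwise the recipe reports the fallback, switches the fuser to `torch.compile`, and `setup_fuser()` registers the `torchcompile` executor instead.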
