
Commit 380fc9f

Add TileGym executor for SDPA and RMSNorm
This adds an optional `tilegym` executor that can dispatch SDPA (prefill/decode) and RMSNorm to TileGym kernels under conservative checkers.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent fb989d4 commit 380fc9f
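
For orientation, the tests in this commit opt into the executor by passing it to `thunder.jit` explicitly. A minimal usage sketch along the same lines (assumes `tilegym` is installed and a CUDA GPU that passes the compute-capability check; the function and shapes here are illustrative, not part of the diff):

import torch
import thunder

tilegym_ex = thunder.get_executor("tilegym")

def fn(q, k, v):
    # An SDPA call the tilegym checker can claim: no mask, no dropout, causal prefill.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(2, 8, 256, 128, device="cuda", dtype=torch.bfloat16)

# List tilegym_ex ahead of the defaults so its checkers are consulted first.
jfn = thunder.jit(fn, executors=(tilegym_ex, *thunder.get_default_executors()))
out = jfn(q, k, v)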

File tree

4 files changed: +330 -0 lines changed


thunder/executors/tilegymex.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import warnings

from lightning_utilities.core.imports import package_available

from thunder import Transform
from thunder.extend import OperatorExecutor

__all__ = ["tilegym_ex", "TileGymTransform"]


tilegym_ex: None | OperatorExecutor = None
TileGymTransform: None | Transform = None


if package_available("tilegym"):
    import thunder.executors.tilegymex_impl as impl

    tilegym_ex = impl.tilegym_ex
    TileGymTransform = impl.TileGymTransform
else:
    warnings.warn("tilegym module not found!")
thunder/executors/tilegymex_impl.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import torch
from lightning_utilities.core.imports import package_available

import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from thunder.core.proxies import pyval
from thunder.extend import OperatorExecutor, register_executor
from thunder import Transform
import thunder.torch as ltorch

if TYPE_CHECKING:
    from thunder.torch import TensorLike


if not package_available("tilegym"):
    raise ImportError("tilegym is required for the tilegym executor")

import tilegym
from tilegym import ops as tg_ops


tilegym_ex: OperatorExecutor = OperatorExecutor("tilegym", version=getattr(tilegym, "__version__", None))
register_executor(tilegym_ex)


def _is_cuda_tensor(t: TensorLike) -> bool:
    return t.device.devicetype == devices.DeviceType.CUDA


def _pybool(x) -> bool:
    try:
        return bool(pyval(x))
    except Exception:
        return False


def _pyfloat_or_none(x) -> float | None:
    if x is None:
        return None
    try:
        return float(pyval(x))
    except Exception:
        return None


def _parse_min_cc(s: str) -> tuple[int, int] | None:
    # Accept "10.0", "10,0", or "100" (treated as "10.0").
    s = (s or "").strip()
    if not s:
        return None
    if "." in s:
        a, b = s.split(".", 1)
        return int(a), int(b)
    if "," in s:
        a, b = s.split(",", 1)
        return int(a), int(b)
    if s.isdigit():
        if len(s) >= 2:
            return int(s[:-1]), int(s[-1])
        return int(s), 0
    return None


def _tilegym_device_cc_ok(device_index: int) -> bool:
    # Default to Blackwell+ (SM100). Override via env vars:
    # - THUNDER_TILEGYM_ALLOW_ANY_CC=1 (bypass)
    # - THUNDER_TILEGYM_MIN_CC=10.0 (set minimum)
    if os.environ.get("THUNDER_TILEGYM_ALLOW_ANY_CC", "0").lower() in ("1", "true", "yes", "y", "on"):
        return True

    min_cc = _parse_min_cc(os.environ.get("THUNDER_TILEGYM_MIN_CC", "10.0"))
    if min_cc is None:
        min_cc = (10, 0)

    if not torch.cuda.is_available():
        return False
    try:
        cc = torch.cuda.get_device_capability(device_index)
    except Exception:
        return False

    return tuple(cc) >= tuple(min_cc)


def _tilegym_sdpa_checker(
    query: TensorLike,
    key: TensorLike,
    value: TensorLike,
    attn_mask: TensorLike | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: float | None = None,
) -> bool:
    # TileGym kernels are CUDA-only.
    if not (_is_cuda_tensor(query) and _is_cuda_tensor(key) and _is_cuda_tensor(value)):
        return False

    if not _tilegym_device_cc_ok(query.device.index):
        return False

    if key.device != query.device or value.device != query.device:
        return False

    # TileGym kernels currently don't support explicit masks or dropout.
    if attn_mask is not None:
        return False

    try:
        dropout_p_val = float(pyval(dropout_p))
    except Exception:
        return False
    if dropout_p_val != 0.0:
        return False

    is_causal_val = _pybool(is_causal)

    # TileGym attention kernels don't implement backward yet.
    if query.requires_grad or key.requires_grad or value.requires_grad:
        return False

    # Expected shapes: (B, H, S, D)
    if query.ndim != 4 or key.ndim != 4 or value.ndim != 4:
        return False

    bq, hq, sq, dq = query.shape
    bk, hk, sk, dk = key.shape
    bv, hv, sv, dv = value.shape

    if bq != bk or bq != bv:
        return False
    if hq != hk or hq != hv:
        # Thunder/torch SDPA expects same number of heads
        return False
    if sk != sv:
        return False
    if dq != dk or dq != dv:
        # TileGym fmha expects Dq == Dk == Dv
        return False

    # TileGym decode kernel assumes non-causal semantics for q_len==1 and k_len>1.
    if sq == 1 and sk > 1 and is_causal_val:
        return False

    # TileGym prefill causal assumes query positions start at 0 and align with keys.
    if is_causal_val and sq != sk:
        return False

    # D requirements: TensorCore-friendly.
    if dq % 8 != 0:
        return False

    # Dtype requirements (TileGym kernels use MMA paths).
    if query.dtype not in (dtypes.float16, dtypes.bfloat16):
        return False
    if key.dtype != query.dtype or value.dtype != query.dtype:
        return False

    # If scale is symbolic/unknown, we can still run (TileGym defaults to 1/sqrt(D)).
    _ = _pyfloat_or_none(scale)

    return True


def _tilegym_sdpa_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: float | None = None,
) -> torch.Tensor:
    # Checker guarantees attn_mask is None and dropout_p == 0.0.
    if query.shape[2] == 1 and key.shape[2] > 1:
        # Decode kernel (non-causal semantics expected; checker enforces that)
        return tg_ops.fmha_decode(query, key, value, sm_scale=scale)
    return tg_ops.fmha(query, key, value, scaling=scale, is_causal=is_causal)


tilegym_sdpa = tilegym_ex.register_operator(
    "tilegym_scaled_dot_product_attention",
    like=ltorch.scaled_dot_product_attention,
    fn=_tilegym_sdpa_impl,
)

tilegym_ex.register_implementation(
    ltorch.scaled_dot_product_attention,
    op=tilegym_sdpa,
    checker=_tilegym_sdpa_checker,
)


def _tilegym_rms_norm_checker(
    a: TensorLike,
    normalized_shape,
    weight: TensorLike | None = None,
    eps: float | None = None,
) -> bool:
    if not _is_cuda_tensor(a):
        return False

    if not _tilegym_device_cc_ok(a.device.index):
        return False

    if weight is None:
        # TileGym rms_norm requires affine weight
        return False
    if not _is_cuda_tensor(weight) or weight.device != a.device:
        return False
    if a.dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32):
        return False
    if weight.dtype != a.dtype:
        return False
    # TileGym rms_norm doesn't implement backward yet.
    # We only enable this when the *activation* does not require grad
    # (typical inference usage).
    if a.requires_grad:
        return False
    # normalized_shape is validated by the underlying op; keep checker minimal.
    return True


def _tilegym_rms_norm_impl(
    a: torch.Tensor,
    normalized_shape,
    weight: torch.Tensor | None = None,
    eps: float | None = None,
) -> torch.Tensor:
    if eps is None:
        eps = torch.finfo(a.dtype).eps if a.dtype.is_floating_point else 0.0
    # Checker ensures weight is present.
    return tg_ops.rms_norm(a, normalized_shape, weight, eps)


TileGymTransform: Transform | None = None

if hasattr(ltorch, "rms_norm"):
    tilegym_rms_norm = tilegym_ex.register_operator(
        "tilegym_rms_norm",
        like=ltorch.rms_norm,
        fn=_tilegym_rms_norm_impl,
    )
    tilegym_ex.register_implementation(
        ltorch.rms_norm,
        op=tilegym_rms_norm,
        checker=_tilegym_rms_norm_checker,
    )
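
The compute-capability gate in `_tilegym_device_cc_ok` above defaults to SM100 (Blackwell) and reads two environment variables. A short sketch of relaxing it (the values other than the `10.0` default are illustrative, and they must be set before the checkers run, i.e. before `thunder.jit` traces the function):

import os

# Bypass the check entirely:
os.environ["THUNDER_TILEGYM_ALLOW_ANY_CC"] = "1"

# ...or lower the minimum instead, e.g. to SM90; "9.0", "9,0", and "90" parse identically.
os.environ["THUNDER_TILEGYM_MIN_CC"] = "9.0"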

thunder/extend/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -538,6 +538,7 @@ def get_all_executors() -> tuple[Executor, ...]:
         pythonex,
         sdpaex,
         fa3ex,
+        tilegymex,
         torch_compile,
         torchex,
         transformer_engineex,
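
With `tilegymex` included in the executor list above (and registered via `register_executor` in the impl module), the executor should be discoverable by name, which is how the tests below obtain it:

import thunder

tilegym_ex = thunder.get_executor("tilegym")
assert tilegym_ex is not None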
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
import pytest
import torch

import thunder
from lightning_utilities.core.imports import package_available


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.skipif(not package_available("tilegym"), reason="requires tilegym")
def test_tilegym_executor_sdpa_rewrites_and_runs():
    tilegym_ex = thunder.get_executor("tilegym")
    assert tilegym_ex is not None

    def fn(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

    # Choose a shape that avoids other SDPA executors' restrictions interfering with this test:
    # - Head dim divisible by 8
    # - No explicit attn_mask, no dropout
    B, H, S, D = 2, 8, 256, 128
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

    jfn = thunder.jit(fn, executors=(tilegym_ex, *thunder.get_default_executors()))
    out = jfn(q, k, v)
    ref = fn(q, k, v)

    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)

    trace = thunder.last_traces(jfn)[-1]
    assert any(bsym.sym.executor is tilegym_ex for bsym in trace.bound_symbols)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.skipif(not package_available("tilegym"), reason="requires tilegym")
def test_tilegym_executor_rms_norm_rewrites_and_runs():
    tilegym_ex = thunder.get_executor("tilegym")
    assert tilegym_ex is not None

    def fn(x, w):
        return torch.nn.functional.rms_norm(x, (x.shape[-1],), w, 1e-6)

    x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16, requires_grad=False)
    w = torch.randn(128, device="cuda", dtype=torch.bfloat16, requires_grad=False)

    jfn = thunder.jit(fn, executors=(tilegym_ex, *thunder.get_default_executors()))
    out = jfn(x, w)
    ref = fn(x, w)

    torch.testing.assert_close(out, ref, atol=0, rtol=0)

    trace = thunder.last_traces(jfn)[-1]
    assert any(bsym.sym.executor is tilegym_ex for bsym in trace.bound_symbols)
