from __future__ import annotations

import os
from typing import TYPE_CHECKING

import torch
from lightning_utilities.core.imports import package_available

import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from thunder.core.proxies import pyval
from thunder.extend import OperatorExecutor, register_executor
from thunder import Transform
import thunder.torch as ltorch

if TYPE_CHECKING:
    from thunder.torch import TensorLike


if not package_available("tilegym"):
    raise ImportError("tilegym is required for the tilegym executor")

import tilegym
from tilegym import ops as tg_ops


tilegym_ex: OperatorExecutor = OperatorExecutor("tilegym", version=getattr(tilegym, "__version__", None))
register_executor(tilegym_ex)

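# Example usage (a sketch, not part of this module): the executor can be passed
# to ``thunder.jit`` so that eligible scaled_dot_product_attention / rms_norm
# calls are claimed by the TileGym kernels, e.g.
#
#   import thunder
#   jitted = thunder.jit(model, executors=[tilegym_ex])
#
# How this composes with other registered executors depends on the caller's setup.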

def _is_cuda_tensor(t: TensorLike) -> bool:
    return t.device.devicetype == devices.DeviceType.CUDA


def _pybool(x) -> bool:
    try:
        return bool(pyval(x))
    except Exception:
        return False


def _pyfloat_or_none(x) -> float | None:
    if x is None:
        return None
    try:
        return float(pyval(x))
    except Exception:
        return None


def _parse_min_cc(s: str) -> tuple[int, int] | None:
    # Accept "10.0", "10,0", or "100" (treated as "10.0").
    s = (s or "").strip()
    if not s:
        return None
    if "." in s:
        a, b = s.split(".", 1)
        return int(a), int(b)
    if "," in s:
        a, b = s.split(",", 1)
        return int(a), int(b)
    if s.isdigit():
        if len(s) >= 2:
            return int(s[:-1]), int(s[-1])
        return int(s), 0
    return None

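# For reference, values the parser above produces:
#   _parse_min_cc("10.0") -> (10, 0)
#   _parse_min_cc("10,0") -> (10, 0)
#   _parse_min_cc("100")  -> (10, 0)
#   _parse_min_cc("9")    -> (9, 0)
#   _parse_min_cc("abc")  -> None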

def _tilegym_device_cc_ok(device_index: int) -> bool:
    # Default to Blackwell+ (SM100). Override via env vars:
    # - THUNDER_TILEGYM_ALLOW_ANY_CC=1 (bypass)
    # - THUNDER_TILEGYM_MIN_CC=10.0 (set minimum)
    if os.environ.get("THUNDER_TILEGYM_ALLOW_ANY_CC", "0").lower() in ("1", "true", "yes", "y", "on"):
        return True

    min_cc = _parse_min_cc(os.environ.get("THUNDER_TILEGYM_MIN_CC", "10.0"))
    if min_cc is None:
        min_cc = (10, 0)

    if not torch.cuda.is_available():
        return False
    try:
        cc = torch.cuda.get_device_capability(device_index)
    except Exception:
        return False

    return tuple(cc) >= tuple(min_cc)

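# Note: torch.cuda.get_device_capability returns, e.g., (9, 0) on Hopper and
# (10, 0) on data-center Blackwell GPUs, so the default minimum of (10, 0)
# limits TileGym to Blackwell-class devices unless overridden via the env vars above.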

def _tilegym_sdpa_checker(
    query: TensorLike,
    key: TensorLike,
    value: TensorLike,
    attn_mask: TensorLike | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: float | None = None,
) -> bool:
    # TileGym kernels are CUDA-only.
    if not (_is_cuda_tensor(query) and _is_cuda_tensor(key) and _is_cuda_tensor(value)):
        return False

    if not _tilegym_device_cc_ok(query.device.index):
        return False

    if key.device != query.device or value.device != query.device:
        return False

    # TileGym kernels currently don't support explicit masks or dropout.
    if attn_mask is not None:
        return False

    try:
        dropout_p_val = float(pyval(dropout_p))
    except Exception:
        return False
    if dropout_p_val != 0.0:
        return False

    is_causal_val = _pybool(is_causal)

    # TileGym attention kernels don't implement backward yet.
    if query.requires_grad or key.requires_grad or value.requires_grad:
        return False

    # Expected shapes: (B, H, S, D)
    if query.ndim != 4 or key.ndim != 4 or value.ndim != 4:
        return False

    bq, hq, sq, dq = query.shape
    bk, hk, sk, dk = key.shape
    bv, hv, sv, dv = value.shape

    if bq != bk or bq != bv:
        return False
    if hq != hk or hq != hv:
        # Thunder/torch SDPA expects same number of heads
        return False
    if sk != sv:
        return False
    if dq != dk or dq != dv:
        # TileGym fmha expects Dq == Dk == Dv
        return False

    # TileGym decode kernel assumes non-causal semantics for q_len==1 and k_len>1.
    if sq == 1 and sk > 1 and is_causal_val:
        return False

    # TileGym prefill causal assumes query positions start at 0 and align with keys.
    if is_causal_val and sq != sk:
        return False

    # D requirements: TensorCore-friendly.
    if dq % 8 != 0:
        return False

    # Dtype requirements (TileGym kernels use MMA paths).
    if query.dtype not in (dtypes.float16, dtypes.bfloat16):
        return False
    if key.dtype != query.dtype or value.dtype != query.dtype:
        return False

    # If scale is symbolic/unknown, we can still run (TileGym defaults to 1/sqrt(D)).
    _ = _pyfloat_or_none(scale)

    return True


def _tilegym_sdpa_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: float | None = None,
) -> torch.Tensor:
    # Checker guarantees attn_mask is None and dropout_p == 0.0.
    if query.shape[2] == 1 and key.shape[2] > 1:
        # Decode kernel (non-causal semantics expected; checker enforces that)
        return tg_ops.fmha_decode(query, key, value, sm_scale=scale)
    return tg_ops.fmha(query, key, value, scaling=scale, is_causal=is_causal)


tilegym_sdpa = tilegym_ex.register_operator(
    "tilegym_scaled_dot_product_attention",
    like=ltorch.scaled_dot_product_attention,
    fn=_tilegym_sdpa_impl,
)

tilegym_ex.register_implementation(
    ltorch.scaled_dot_product_attention,
    op=tilegym_sdpa,
    checker=_tilegym_sdpa_checker,
)

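# For reference, an SDPA call the registration above is intended to claim (sketch):
# 4D CUDA tensors of shape (B, H, S, D) in float16/bfloat16, no attn_mask, no dropout,
# D a multiple of 8, e.g. query/key/value of shape (1, 8, 1024, 64) with is_causal=True.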

def _tilegym_rms_norm_checker(
    a: TensorLike,
    normalized_shape,
    weight: TensorLike | None = None,
    eps: float | None = None,
) -> bool:
    if not _is_cuda_tensor(a):
        return False

    if not _tilegym_device_cc_ok(a.device.index):
        return False

    if weight is None:
        # TileGym rms_norm requires affine weight
        return False
    if not _is_cuda_tensor(weight) or weight.device != a.device:
        return False
    if a.dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32):
        return False
    if weight.dtype != a.dtype:
        return False
    # TileGym rms_norm doesn't implement backward yet.
    # We only enable this when the *activation* does not require grad
    # (typical inference usage).
    if a.requires_grad:
        return False
    # normalized_shape is validated by the underlying op; keep checker minimal.
    return True


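# RMS normalization over ``normalized_shape``: y = x / sqrt(mean(x**2) + eps) * weight,
# i.e. the semantics of torch.nn.functional.rms_norm that ltorch.rms_norm mirrors.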
def _tilegym_rms_norm_impl(
    a: torch.Tensor,
    normalized_shape,
    weight: torch.Tensor | None = None,
    eps: float | None = None,
) -> torch.Tensor:
    if eps is None:
        eps = torch.finfo(a.dtype).eps if a.dtype.is_floating_point else 0.0
    # Checker ensures weight is present.
    return tg_ops.rms_norm(a, normalized_shape, weight, eps)


TileGymTransform: Transform | None = None

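# ltorch.rms_norm may not exist on older Thunder versions, so the TileGym
# implementation is registered only when the symbol is available.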
if hasattr(ltorch, "rms_norm"):
    tilegym_rms_norm = tilegym_ex.register_operator(
        "tilegym_rms_norm",
        like=ltorch.rms_norm,
        fn=_tilegym_rms_norm_impl,
    )
    tilegym_ex.register_implementation(
        ltorch.rms_norm,
        op=tilegym_rms_norm,
        checker=_tilegym_rms_norm_checker,
    )