 except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False

+# [Optional] TLX backend
+try:
+    import triton.language.extra.tlx as tlx
+
+    from .tlx_attn_ws_pipelined_pingpong_hopper import (
+        attention as tlx_attn_ws_pipelined_pingpong_hopper,
+    )
+
+    HAS_TLX = True
+except (ImportError, IOError, AttributeError):
+    HAS_TLX = False
+
 from typing import Any, Generator, List

 from tritonbench.utils.input import input_filter

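The hunk above makes TLX strictly optional: the import of `triton.language.extra.tlx` and of the local kernel module is probed once at import time, and only the `HAS_TLX` flag escapes the try/except, so the operator still loads on Triton builds without TLX. A minimal standalone sketch of the same gate follows; the `pick_backend` helper is hypothetical and not part of the diff.

try:
    # Present only in TLX-enabled Triton builds; any failure simply disables the backend.
    import triton.language.extra.tlx as tlx  # noqa: F401
    HAS_TLX = True
except (ImportError, IOError, AttributeError):
    HAS_TLX = False

def pick_backend() -> str:
    # Hypothetical helper: prefer the TLX kernel only when the probe succeeded.
    return "tlx_attn_ws_pipelined_pingpong_hopper" if HAS_TLX else "triton_tutorial_flash_v2"
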
@@ -299,6 +311,16 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )

+    @register_benchmark(enabled=HAS_TLX)
+    def tlx_attn_ws_pipelined_pingpong_hopper(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        # TLX flash attention with Hopper optimizations
+        return lambda: tlx_attn_ws_pipelined_pingpong_hopper(q, k, v, self.sm_scale)
+
     def xformers_preprocess(
         self,
         q: torch.Tensor,

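The new method is registered with `@register_benchmark(enabled=HAS_TLX)`, so it only appears when the TLX import above succeeded, and it returns a zero-argument lambda that the harness can call repeatedly for timing. Below is a hedged sketch of exercising the same kernel outside the harness; the shapes, dtype, device, and the absolute import path are illustrative assumptions, while the `(q, k, v, sm_scale)` signature is taken from the diff.

import torch

# Assumed import path; in the operator it is a relative import from the same package.
from tlx_attn_ws_pipelined_pingpong_hopper import attention as tlx_attn

# Illustrative flash-attention problem size: fp16, head_dim 64, Hopper-class GPU assumed.
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 64
q = torch.randn(BATCH, N_HEADS, N_CTX, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = D_HEAD ** -0.5  # same softmax scaling the operator passes as self.sm_scale

out = tlx_attn(q, k, v, sm_scale)  # the call the registered lambda defers until timing

Inside tritonbench itself, the variant would typically be selected with something like `python run.py --op flash_attention --only tlx_attn_ws_pipelined_pingpong_hopper`, assuming the standard entry point and `--only` filter.
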
@@ -341,7 +363,9 @@ def xformers_splitk(
             fhma_input, needs_gradient=need_gradient
         )

-    @register_benchmark(enabled=False, label=f"cudnn-{torch.backends.cudnn.version()}")
+    @register_benchmark(
+        enabled=False
+    )  # , label=f"cudnn-{torch.backends.cudnn.version()}")
     def cudnn(self, q, k, v):
         os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

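The cudnn variant stays disabled and drops its dynamic label; the interesting part is the environment variable, since `TORCH_CUDNN_SDPA_ENABLED=1` opts PyTorch's `scaled_dot_product_attention` into the cuDNN backend. A hedged sketch of the same effect through the public backend selector, assuming a recent PyTorch that ships `torch.nn.attention.sdpa_kernel` and `SDPBackend.CUDNN_ATTENTION`; tensors are illustrative.

import os
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"  # same switch the benchmark flips

q = torch.randn(4, 32, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Pin SDPA to the cuDNN backend for this call (context-manager form).
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)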