
Commit 9e4ffe6

Add TLX attention (WS pipelined pingpong hopper)

1 parent 4bc0d04

File tree

2 files changed: +406 -1 lines changed

(Only tritonbench/operators/flash_attention/operator.py is shown below; the other
file is the new TLX kernel module pulled in by the import in the first hunk, which
accounts for the remaining additions.)

tritonbench/operators/flash_attention/operator.py

Lines changed: 25 additions & 1 deletion
@@ -136,6 +136,18 @@
 except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
+# [Optional] TLX backend
+try:
+    import triton.language.extra.tlx as tlx
+
+    from .tlx_attn_ws_pipelined_pingpong_hopper import (
+        attention as tlx_attn_ws_pipelined_pingpong_hopper,
+    )
+
+    HAS_TLX = True
+except (ImportError, IOError, AttributeError):
+    HAS_TLX = False
+
 from typing import Any, Generator, List
 
 from tritonbench.utils.input import input_filter
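
The try/except guard above makes TLX a soft dependency: on Triton builds without
triton.language.extra.tlx, HAS_TLX stays False and the benchmark registered in the
next hunk is skipped instead of breaking the whole operator at import time. As a
rough illustration of how such a flag can gate registration, here is a minimal
sketch of a conditional registration decorator; the registry and signature are
assumptions for illustration, not tritonbench's actual implementation.

    # Illustrative sketch only -- NOT tritonbench's real register_benchmark.
    # A decorator whose `enabled` flag (e.g. HAS_TLX) decides whether the
    # benchmark is visible to the runner; the function itself stays intact.
    from typing import Callable, Dict, Optional

    _BENCHMARKS: Dict[str, Callable] = {}

    def register_benchmark(enabled: bool = True, label: Optional[str] = None):
        def decorator(fn: Callable) -> Callable:
            if enabled:
                _BENCHMARKS[label or fn.__name__] = fn
            return fn  # the method stays callable even when not registered
        return decorator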
@@ -299,6 +311,16 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )
 
+    @register_benchmark(enabled=HAS_TLX)
+    def tlx_attn_ws_pipelined_pingpong_hopper(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        # TLX flash attention with Hopper optimizations
+        return lambda: tlx_attn_ws_pipelined_pingpong_hopper(q, k, v, self.sm_scale)
+
     def xformers_preprocess(
         self,
         q: torch.Tensor,
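
Note that the new benchmark method does not execute the kernel itself: it returns a
zero-argument lambda so the harness can invoke and time the callable repeatedly,
keeping argument setup outside the measured region. ("WS pipelined pingpong"
presumably stands for the warp-specialized, software-pipelined, ping-pong-scheduled
variant of the kernel targeting Hopper GPUs.) A minimal sketch of the deferred-call
pattern, with a hypothetical bench() helper -- real GPU timing would use CUDA events
or torch.cuda.synchronize() rather than wall-clock time:

    import time
    from typing import Callable

    def bench(fn: Callable[[], object], iters: int = 100) -> float:
        """Hypothetical helper: average seconds per call of a zero-arg callable."""
        fn()  # warmup (compilation, autotuning, caches)
        t0 = time.perf_counter()
        for _ in range(iters):
            fn()
        return (time.perf_counter() - t0) / iters

    # Usage mirroring the benchmark above (q, k, v, sm_scale assumed defined):
    # avg_s = bench(lambda: tlx_attn_ws_pipelined_pingpong_hopper(q, k, v, sm_scale))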
@@ -341,7 +363,9 @@ def xformers_splitk(
             fhma_input, needs_gradient=need_gradient
         )
 
-    @register_benchmark(enabled=False, label=f"cudnn-{torch.backends.cudnn.version()}")
+    @register_benchmark(
+        enabled=False
+    )  # , label=f"cudnn-{torch.backends.cudnn.version()}")
     def cudnn(self, q, k, v):
         os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
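
The unrelated cudnn change moves the f-string label into a trailing comment,
presumably because evaluating torch.backends.cudnn.version() at class-definition
time fails or returns None on machines without cuDNN. Assuming tritonbench's usual
CLI (an assumption; verify against the repo's README), the new variant could then
be run with something like:

    python run.py --op flash_attention --only tlx_attn_ws_pipelined_pingpong_hopper

which only lists the TLX backend when HAS_TLX is true.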
