Commit 578397f

Add TLX attention (WS pipelined pingpong hopper)
1 parent 4bc0d04
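
In this kernel family's naming, "WS" is warp specialization (producer warps issue async loads while consumer warps compute), "pipelined" means tiles are multi-buffered so loads overlap compute, and "pingpong" refers to two consumer warp groups alternating so one runs softmax while the other occupies the tensor cores. As a rough, API-free illustration of the double-buffered pipelining part only (plain Python; TLX's actual constructs are not shown):

    # Conceptual sketch of a double-buffered load/compute pipeline.
    # On the GPU the load would be an async copy overlapping compute;
    # serial Python only shows the buffer rotation, not the overlap.
    def pingpong_pipeline(tiles, load, compute):
        buffers = [None, None]
        buffers[0] = load(tiles[0])  # prologue: prefetch the first tile
        results = []
        for i in range(len(tiles)):
            if i + 1 < len(tiles):
                buffers[(i + 1) % 2] = load(tiles[i + 1])  # producer: prefetch next
            results.append(compute(buffers[i % 2]))  # consumer: compute current tile
        return results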

File tree

2 files changed: +552, -0 lines

  tritonbench/operators/flash_attention/operator.py
  tritonbench/operators/flash_attention/tlx_attn_ws_pipelined_pingpong_hopper.py


tritonbench/operators/flash_attention/operator.py

Lines changed: 22 additions & 0 deletions
@@ -61,6 +61,10 @@
 from tritonbench.utils.path_utils import add_ld_library_path
 from tritonbench.utils.triton_op import is_fbcode
 
+from .tlx_attn_ws_pipelined_pingpong_hopper import (
+    attention as tlx_attn_ws_pipelined_pingpong_hopper,
+)
+
 
 # [Optional] flash_attn v2
 try:
@@ -136,6 +140,14 @@
 except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
+# [Optional] TLX backend
+try:
+    import triton.language.extra.tlx as tlx
+
+    HAS_TLX = True
+except (ImportError, IOError, AttributeError):
+    HAS_TLX = False
+
 from typing import Any, Generator, List
 
 from tritonbench.utils.input import input_filter
@@ -299,6 +311,16 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )
 
+    @register_benchmark(enabled=HAS_TLX)
+    def tlx_attn_ws_pipelined_pingpong_hopper(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        # TLX flash attention with Hopper optimizations
+        return lambda: tlx_attn_ws_pipelined_pingpong_hopper(q, k, v, self.sm_scale)
+
     def xformers_preprocess(
         self,
         q: torch.Tensor,
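
For reference, the registered benchmark above just closes over the `attention` entry point imported at the top of the file. The following is a minimal standalone sketch of calling that entry point directly; it assumes only the `attention(q, k, v, sm_scale)` signature visible in this diff, the tensor shapes and dtype are illustrative, and a Hopper (sm90) GPU is required:

    # Usage sketch (not part of this commit); shapes/dtype are illustrative.
    import torch

    from tritonbench.operators.flash_attention.tlx_attn_ws_pipelined_pingpong_hopper import (
        attention as tlx_attention,
    )

    BATCH, H, N_CTX, D_HEAD = 4, 32, 4096, 64
    q, k, v = (
        torch.randn(BATCH, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    sm_scale = D_HEAD**-0.5  # conventional 1/sqrt(d); the harness passes self.sm_scale
    out = tlx_attention(q, k, v, sm_scale)  # output shape matches q

Because the method is registered with `enabled=HAS_TLX`, it only appears when `triton.language.extra.tlx` imports successfully; selecting it by name through tritonbench's usual `python run.py --op flash_attention --only tlx_attn_ws_pipelined_pingpong_hopper` invocation is an assumption about the CLI, not part of this diff.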
