Commit 6d5389b

Add TLX attention (WS pipelined pingpong hopper)
1 parent 4bc0d04

File tree

2 files changed: +405 -0 lines

tritonbench/operators/flash_attention/operator.py

Lines changed: 22 additions & 0 deletions
@@ -136,6 +136,18 @@
 except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
+# [Optional] TLX backend
+try:
+    import triton.language.extra.tlx as tlx
+
+    from .tlx_attn_ws_pipelined_pingpong_hopper import (
+        attention as tlx_attn_ws_pipelined_pingpong_hopper,
+    )
+
+    HAS_TLX = True
+except (ImportError, IOError, AttributeError):
+    HAS_TLX = False
+
 from typing import Any, Generator, List
 
 from tritonbench.utils.input import input_filter
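The guarded import above follows the same optional-backend pattern the file already uses for xformers: if the TLX extension or the kernel module cannot be imported, the suite still loads and the backend is simply marked unavailable via HAS_TLX. A minimal sketch of that pattern, using the hypothetical placeholder names some_backend, kernel_fn, and HAS_BACKEND for illustration:

    # Minimal sketch of the optional-backend guard; `some_backend` and
    # `kernel_fn` are hypothetical placeholders, not real tritonbench names.
    try:
        import some_backend  # may be absent from the environment

        from some_backend import kernel_fn
        HAS_BACKEND = True
    except (ImportError, IOError, AttributeError):
        # A missing package, broken install, or API drift disables the
        # backend instead of crashing the whole suite at import time.
        HAS_BACKEND = False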
@@ -299,6 +311,16 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )
 
+    @register_benchmark(enabled=HAS_TLX)
+    def tlx_attn_ws_pipelined_pingpong_hopper(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        # TLX flash attention with Hopper optimizations
+        return lambda: tlx_attn_ws_pipelined_pingpong_hopper(q, k, v, self.sm_scale)
+
     def xformers_preprocess(
         self,
         q: torch.Tensor,
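Note that the benchmark method is gated with enabled=HAS_TLX and returns a zero-argument lambda rather than calling the kernel directly, so the harness can warm up and time the call repeatedly. A rough sketch of how such a lambda might be timed; this is a hypothetical timing loop for illustration, not tritonbench's actual harness:

    import time
    from typing import Callable

    import torch

    def time_callable(fn: Callable[[], torch.Tensor], iters: int = 100) -> float:
        fn()  # warm-up: keep compilation/autotuning out of the measurement
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()  # GPU kernels launch asynchronously
        return (time.perf_counter() - start) / iters  # mean seconds per call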
