Commit f6d16a1

Add TLX attention (WS pipelined pingpong hopper)
1 parent 4bc0d04 commit f6d16a1

2 files changed: +337, -1 lines changed

tritonbench/operators/flash_attention/operator.py

Lines changed: 25 additions & 1 deletion
@@ -61,6 +61,10 @@
 from tritonbench.utils.path_utils import add_ld_library_path
 from tritonbench.utils.triton_op import is_fbcode
 
+from .tlx_attn_ws_pipelined_pingpong_hopper import (
+    attention as tlx_attn_ws_pipelined_pingpong_hopper,
+)
+
 
 # [Optional] flash_attn v2
 try:
@@ -136,6 +140,14 @@
 except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
+# [Optional] TLX backend
+try:
+    import triton.language.extra.tlx as tlx
+
+    HAS_TLX = True
+except (ImportError, IOError, AttributeError):
+    HAS_TLX = False
+
 from typing import Any, Generator, List
 
 from tritonbench.utils.input import input_filter
@@ -299,6 +311,18 @@ def triton_tutorial_flash_v2_tma(
             q, k, v, self.causal, self.sm_scale, "tma"
         )
 
+    @register_benchmark(enabled=HAS_TLX)
+    def tlx_attn_ws_pipelined_pingpong_hopper(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        # TLX flash attention with Hopper optimizations
+        return lambda: tlx_attn_ws_pipelined_pingpong_hopper(
+            q, k, v, self.sm_scale
+        )
+
     def xformers_preprocess(
         self,
         q: torch.Tensor,
@@ -341,7 +365,7 @@ def xformers_splitk(
             fhma_input, needs_gradient=need_gradient
         )
 
-    @register_benchmark(enabled=False, label=f"cudnn-{torch.backends.cudnn.version()}")
+    @register_benchmark(enabled=False)  # , label=f"cudnn-{torch.backends.cudnn.version()}")
     def cudnn(self, q, k, v):
         os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
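
With this change, the backend is registered as tlx_attn_ws_pipelined_pingpong_hopper whenever triton.language.extra.tlx imports successfully (HAS_TLX), so machines without a TLX-enabled Triton build skip it cleanly. Assuming tritonbench's usual entry point (the exact flags are an assumption, not part of this commit), a run such as "python run.py --op flash_attention --only tlx_attn_ws_pipelined_pingpong_hopper" would exercise the new kernel.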

tritonbench/operators/flash_attention/tlx_attn_ws_pipelined_pingpong_hopper.py (new file)

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
import torch

import triton
import triton.language as tl
import triton.language.extra.tlx as tlx
from triton.tools.tensor_descriptor import TensorDescriptor

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def _host_descriptor_pre_hook(nargs):
    BLOCK_M = nargs["BLOCK_M"]
    BLOCK_N = nargs["BLOCK_N"]
    HEAD_DIM = nargs["HEAD_DIM"]
    if not isinstance(nargs["desc_q"], TensorDescriptor):
        return
    NUM_MMA_GROUPS = nargs["NUM_MMA_GROUPS"]
    BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
    nargs["desc_q"].block_shape = [BLOCK_M_SPLIT, HEAD_DIM]
    if nargs["FP8_OUTPUT"]:
        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
    else:
        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
    nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
    nargs["desc_o"].block_shape = [BLOCK_M_SPLIT, HEAD_DIM]


configs = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'NUM_BUFFERS': 2, 'NUM_MMA_WARPS': 8, 'NUM_MMA_GROUPS': 2},
                  num_stages=0, num_warps=4, pre_hook=_host_descriptor_pre_hook),
]
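
# A note on the single autotune config above: BLOCK_M x BLOCK_N = 128 x 128 tiles,
# with NUM_MMA_WARPS = 8 compute warps split into NUM_MMA_GROUPS = 2 replicated
# consumer groups of four warps each; num_warps=4 sizes the "default" (producer)
# task, and NUM_BUFFERS = 2 double-buffers the K/V tiles so that TMA loads can
# overlap the MMAs.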

@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
@triton.jit
def _attn_fwd_ws_pipelined_pingpong(sm_scale, M,  #
                                    Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
                                    HEAD_DIM: tl.constexpr,  #
                                    BLOCK_M: tl.constexpr,  #
                                    BLOCK_N: tl.constexpr,  #
                                    FP8_OUTPUT: tl.constexpr,  #
                                    NUM_BUFFERS: tl.constexpr,  #
                                    NUM_MMA_WARPS: tl.constexpr,  #
                                    NUM_MMA_GROUPS: tl.constexpr,  #
                                    ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS

    # allocate buffers
    q_tiles = tlx.local_alloc((BLOCK_M_SPLIT, HEAD_DIM), tlx.dtype_of(desc_q), NUM_MMA_GROUPS)
    k_tiles = tlx.local_alloc((BLOCK_N, HEAD_DIM), tlx.dtype_of(desc_k), NUM_BUFFERS)
    v_tiles = tlx.local_alloc((BLOCK_N, HEAD_DIM), tlx.dtype_of(desc_v), NUM_BUFFERS)

    # allocate barriers
    q_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1)
    k_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS)
    k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
    v_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS)
    v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1)
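
    # Synchronization scheme for the producer/consumer tasks below: each "full"
    # barrier is signaled when the TMA load into its buffer completes, and each
    # "empty" barrier is signaled by consumers once they are done reading it.
    # The K/V empty barriers use arrive_count=NUM_MMA_GROUPS because both
    # replicated consumer groups must release a buffer before it is reloaded;
    # waiters toggle a phase bit each time a barrier slot is reused.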

    with tlx.async_tasks():
        # producer group
        with tlx.async_task("default"):
            # initialize offsets
            start_m = tl.program_id(0)
            off_hz = tl.program_id(1)
            off_z = off_hz // H
            off_h = off_hz % H
            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
            qo_offset_y = offset_y + start_m * BLOCK_M
            lo, hi = 0, N_CTX
            kv_offset_y = offset_y + lo

            # load q: it will stay in SRAM throughout
            for cid in tl.range(0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS):
                q_full = tlx.local_view(q_fulls, cid)
                tlx.barrier_expect_bytes(q_full, 2 * BLOCK_M_SPLIT * HEAD_DIM)  # float16
                q_tile = tlx.local_view(q_tiles, cid)
                qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
                tlx.async_descriptor_load(desc_q, q_tile, [qo_offset_y_split, 0], q_full)

            # loop over loading k, v
            kv_phase = 0
            acc_cnt = 0
            for _ in tl.range(lo, hi, BLOCK_N):
                buf_id = acc_cnt % NUM_BUFFERS
                # buffers in a row share the same phase
                kv_phase = kv_phase ^ (buf_id == 0)

                # wait for the K buffer to be released by the consumer
                k_empty = tlx.local_view(k_empties, buf_id)
                tlx.barrier_wait(k_empty, kv_phase)
                # load K
                k_full = tlx.local_view(k_fulls, buf_id)
                k_tile = tlx.local_view(k_tiles, buf_id)
                tlx.barrier_expect_bytes(k_full, 2 * BLOCK_N * HEAD_DIM)  # float16
                tlx.async_descriptor_load(desc_k, k_tile, [kv_offset_y, 0], k_full)

                # wait for the V buffer to be released by the consumer
                v_empty = tlx.local_view(v_empties, buf_id)
                tlx.barrier_wait(v_empty, kv_phase)
                # load V
                v_full = tlx.local_view(v_fulls, buf_id)
                v_tile = tlx.local_view(v_tiles, buf_id)
                tlx.barrier_expect_bytes(v_full, 2 * BLOCK_N * HEAD_DIM)  # float16
                tlx.async_descriptor_load(desc_v, v_tile, [kv_offset_y, 0], v_full)

                kv_offset_y += BLOCK_N
                acc_cnt += 1

        # consumer group
        with tlx.async_task(num_warps=NUM_MMA_WARPS // NUM_MMA_GROUPS, registers=232, replicate=NUM_MMA_GROUPS):
            # initialize pointer to m and l
            m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf")
            l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0
            acc = tl.zeros([BLOCK_M_SPLIT, HEAD_DIM], dtype=tl.float32)

            # load scales
            qk_scale = sm_scale
            qk_scale *= 1.44269504  # log2(e) = 1/ln(2)

            # wait for the Q buffer to be populated by the producer
            cid: tl.constexpr = tlx.async_task_replica_id()
            q_full = tlx.local_view(q_fulls, cid)
            tlx.barrier_wait(q_full, 0)
            q_tile = tlx.local_view(q_tiles, cid)

            lo, hi = 0, N_CTX
            k_phase = 0
            v_phase = 1
            k_buf_id = 0
            v_buf_id = 0

            # wait for the K[0] buffer to be populated by the producer
            k_full = tlx.local_view(k_fulls, k_buf_id)
            tlx.barrier_wait(k_full, k_phase)
            k_tile = tlx.local_view(k_tiles, k_buf_id)

            # -- compute qk[0] ----
            k_tile = tlx.local_trans(k_tile)
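
            # Ping-pong scheduling: named barriers 9 and 10 (sized for the 256
            # threads of the two consumer groups) stagger the replicas so that
            # only one group issues its QK MMA at a time, letting one group's
            # softmax overlap the other group's tensor-core work.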
            if cid == 0:
                # Consumer 0 waits for Consumer 1 to reach the synchronization point at barrier 9.
                tlx.named_barrier_wait(9, 256)
            else:
                # Consumer 1 signals its arrival at barrier 9.
                tlx.named_barrier_arrive(9, 256)
                # Then it waits at barrier 10 until Consumer 0 finishes issuing its async_dot.
                tlx.named_barrier_wait(10, 256)

            qk = tlx.async_dot(q_tile, k_tile)

            if cid == 0:
                # After issuing async_dot, Consumer 0 signals barrier 10 to unblock Consumer 1.
                tlx.named_barrier_arrive(10, 256)

            # wait for the MMA to complete
            qk = tlx.async_dot_wait(0, qk)
            # release the K buffer
            k_empty = tlx.local_view(k_empties, k_buf_id)
            tlx.barrier_arrive(k_empty, 1)

            # -- compute m_i and l_i ----
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
            p = tl.math.exp2(qk)
            # -- compute correction factor
            alpha = tl.math.exp2(m_i - m_ij)
            # -- update output accumulator[0] --
            acc = acc * alpha[:, None]
            l_ij = tl.sum(p, 1)
            l_i = l_i * alpha + l_ij
            m_i = m_ij
            acc_cnt = 1
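
            # Online-softmax invariant for the loop below: m_i is the running
            # row max and l_i the running denominator; each new tile rescales
            # the old state by alpha = exp2(m_i - m_ij) so that, at the end,
            # acc / l_i equals softmax(q @ k^T * sm_scale) @ v for these rows.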
            # loop over k, v and update accumulator
            for _ in tl.range(lo + BLOCK_N, hi, BLOCK_N):
                k_buf_id = acc_cnt % NUM_BUFFERS
                # buffers in a row share the same phase
                k_phase = k_phase ^ (k_buf_id == 0)

                # wait for the K buffer to be populated by the producer
                k_full = tlx.local_view(k_fulls, k_buf_id)
                tlx.barrier_wait(k_full, k_phase)
                k_tile = tlx.local_view(k_tiles, k_buf_id)

                # compute qk for the current iteration
                k_tile = tlx.local_trans(k_tile)
                qk = tlx.async_dot(q_tile, k_tile)

                # compute pv from the previous iteration
                # wait for the previous V buffer to be populated by the producer
                v_buf_id = (acc_cnt - 1) % NUM_BUFFERS
                v_phase = v_phase ^ (v_buf_id == 0)
                v_full = tlx.local_view(v_fulls, v_buf_id)
                tlx.barrier_wait(v_full, v_phase)
                v_tile = tlx.local_view(v_tiles, v_buf_id)
                # prepare p and v for the dot
                p = p.to(tlx.dtype_of(desc_k))
                acc = tlx.async_dot(p, v_tile, acc)

                # wait for the current qk MMA to complete
                qk = tlx.async_dot_wait(1, qk)
                # release the K buffer
                k_empty = tlx.local_view(k_empties, k_buf_id)
                tlx.barrier_arrive(k_empty, 1)

                # -- compute m_i and l_i ----
                m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
                qk = qk * qk_scale - m_ij[:, None]
                p = tl.math.exp2(qk)
                # -- compute correction factor
                alpha = tl.math.exp2(m_i - m_ij)
                l_ij = tl.sum(p, 1)
                # update m_i and l_i
                l_i = l_i * alpha + l_ij
                m_i = m_ij

                # -- update output accumulator --
                # wait for the previous pv MMA to complete
                acc = tlx.async_dot_wait(0, acc)
                # release the V buffer
                v_empty = tlx.local_view(v_empties, v_buf_id)
                tlx.barrier_arrive(v_empty, 1)
                acc = acc * alpha[:, None]
                acc_cnt += 1
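
            # Pipelining note: inside the loop above, the QK MMA for tile i is
            # issued before waiting on the PV MMA for tile i - 1, and
            # async_dot_wait(1, qk) returns once at most one MMA (the in-flight
            # PV) is still outstanding, keeping the two dots overlapped.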
            # compute pv from the last iteration
            # wait for the V buffer to be populated by the producer
            v_buf_id = (acc_cnt - 1) % NUM_BUFFERS
            v_phase = v_phase ^ (v_buf_id == 0)
            v_full = tlx.local_view(v_fulls, v_buf_id)
            tlx.barrier_wait(v_full, v_phase)
            v_tile = tlx.local_view(v_tiles, v_buf_id)
            # prepare p and v for the dot
            p = p.to(tlx.dtype_of(desc_k))
            acc = tlx.async_dot(p, v_tile, acc)
            # wait for the MMA to complete
            acc = tlx.async_dot_wait(0, acc)
            # release the V buffer
            v_empty = tlx.local_view(v_empties, v_buf_id)
            tlx.barrier_arrive(v_empty, 1)

            # epilogue
            start_m = tl.program_id(0)
            off_hz = tl.program_id(1)
            off_z = off_hz // H
            off_h = off_hz % H
            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
            qo_offset_y = offset_y + start_m * BLOCK_M
            qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
            # m_i + log2(l_i) is the per-row log-sum-exp in base 2, saved in M
            # so a backward pass can recompute the softmax cheaply.
            m_i += tl.math.log2(l_i)
            acc = acc / l_i[:, None]
            offs_m = start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT)
            m_ptrs = M + off_hz * N_CTX + offs_m
            tl.store(m_ptrs, m_i)
            desc_o.store([qo_offset_y_split, 0], acc.to(tlx.dtype_of(desc_o)))


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        extra_kern_args = {}

        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
        y_dim = q.shape[0] * q.shape[1] * q.shape[2]

        dummy_block = [1, 1]
        desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        if q.dtype == torch.float8_e5m2:
            desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1], block_shape=dummy_block)
        else:
            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)

        def alloc_fn(size: int, align: int, _):
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(alloc_fn)

        def grid(META):
            return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

        ctx.grid = grid
        _attn_fwd_ws_pipelined_pingpong[grid](
            sm_scale, M,  #
            q.shape[0], q.shape[1],  #
            desc_q, desc_k, desc_v, desc_o,  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        return o


attention = _attention.apply
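
For a quick sanity check of the new kernel, a forward-only smoke test along these lines should work on a Hopper GPU (a sketch, not part of the commit; the shapes, tolerances, and import path are illustrative assumptions):

import math

import torch

from tritonbench.operators.flash_attention.tlx_attn_ws_pipelined_pingpong_hopper import attention

Z, H, N_CTX, HEAD_DIM = 4, 8, 1024, 128  # hypothetical sizes; N_CTX must be a multiple of BLOCK_N
q = torch.randn(Z, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 1.0 / math.sqrt(HEAD_DIM)

out = attention(q, k, v, sm_scale)  # forward only; this file defines no backward
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=0)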
