Skip to content

Commit d0e6fa5

Browse files
htyu authored and meta-codesync[bot] committed
[TLX] Pipelining flash attention bwd kernel (#630)
Summary: Key changes - Pipeline the load and mma task - Use 8 warps instead of 4 for the compute task - Adjust buffer stages and register requirements Before ``` fused-attention-ws-pipelined-persistent-batch4-head32-d128: N_CTX Triton [FP16] 0 1024.0 322.140956 1 2048.0 394.705387 2 4096.0 440.500436 3 8192.0 460.003597 4 16384.0 472.748830 ``` After ``` fused-attention-ws-pipelined-persistent-batch4-head32-d128: N_CTX Triton [FP16] 0 1024.0 433.543387 1 2048.0 550.770631 2 4096.0 598.717550 3 8192.0 643.836015 4 16384.0 655.650778 ``` Pull Request resolved: #630 Reviewed By: manman-ren Differential Revision: D86162895 Pulled By: htyu fbshipit-source-id: 9e87337a4af5880657322537c3290596065ab33f
1 parent f36e7fc commit d0e6fa5

File tree

1 file changed

+158
-48
lines changed

1 file changed

+158
-48
lines changed

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py

Lines changed: 158 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -824,11 +824,11 @@ def _bwd_host_descriptor_pre_hook_tlx(nargs):
824824
"BLOCK_M2": 128,
825825
"BLOCK_N2": 128,
826826
"NUM_BUFFERS_KV": 1,
827-
"NUM_BUFFERS_Q": 1,
827+
"NUM_BUFFERS_Q": 2,
828828
"NUM_BUFFERS_DO": 1,
829829
"NUM_BUFFERS_DS": 1,
830830
"NUM_BUFFERS_TMEM": 1,
831-
"EPILOGUE_SUBTILE": 2,
831+
"EPILOGUE_SUBTILE": 4,
832832
},
833833
num_warps=4,
834834
num_stages=1,
@@ -839,7 +839,7 @@ def _bwd_host_descriptor_pre_hook_tlx(nargs):
839839

840840
@triton.autotune(configs=configs_bwd_tlx, key=["N_CTX", "HEAD_DIM"])
841841
@triton.jit
842-
def _attn_bwd(
842+
def _attn_bwd_ws(
843843
desc_q,
844844
desc_k,
845845
desc_v,
@@ -874,18 +874,18 @@ def _attn_bwd(
874874
k_tiles = tlx.local_alloc((BLOCK_N1, HEAD_DIM), tlx.dtype_of(desc_k), NUM_BUFFERS_KV)
875875
v_tiles = tlx.local_alloc((BLOCK_N1, HEAD_DIM), tlx.dtype_of(desc_v), NUM_BUFFERS_KV)
876876
q_tiles = tlx.local_alloc((BLOCK_M1, HEAD_DIM), tlx.dtype_of(desc_q), NUM_BUFFERS_Q)
877-
do_tiles = tlx.local_alloc((BLOCK_M1, HEAD_DIM), tlx.dtype_of(desc_do), NUM_BUFFERS_Q)
877+
do_tiles = tlx.local_alloc((BLOCK_M1, HEAD_DIM), tlx.dtype_of(desc_do), NUM_BUFFERS_DO)
878878

879879
# Use SMEM for dsT
880-
ds_tiles = tlx.local_alloc((BLOCK_N1, BLOCK_M1), tlx.dtype_of(desc_q), NUM_BUFFERS_TMEM)
880+
ds_tiles = tlx.local_alloc((BLOCK_N1, BLOCK_M1), tlx.dtype_of(desc_q), NUM_BUFFERS_DS)
881881

882882
# allocate barriers for smem buffers
883883
k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV)
884884
v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV)
885885
q_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q)
886886
q_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q)
887-
do_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q)
888-
do_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q)
887+
do_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_DO)
888+
do_empties = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_DO)
889889
ds_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_TMEM)
890890

891891
# allocate tmem buffers
@@ -951,7 +951,7 @@ def _attn_bwd(
951951
curr_m += step_m
952952

953953
# compute
954-
with tlx.async_task(num_warps=4, registers=136, replicate=1):
954+
with tlx.async_task(num_warps=8, registers=192, replicate=1):
955955
off_chz, off_bh, start_m, start_n, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H,
956956
N_CTX, BLOCK_M1, BLOCK_N1)
957957

@@ -962,6 +962,7 @@ def _attn_bwd(
962962
step_m = BLOCK_M1
963963
for blk_idx in range(num_steps):
964964
tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
965+
ds_buf_id, _ = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DS)
965966

966967
offs_m = curr_m + tl.arange(0, BLOCK_M1)
967968
m = tl.load(M + offs_m)
@@ -989,9 +990,9 @@ def _attn_bwd(
989990
# in the same iteration. Release dQ instead later.
990991
dsT = pT * (dpT - Di[None, :])
991992
dsT = dsT.to(tlx.dtype_of(desc_q))
992-
tlx.local_store(ds_tiles[tmem_buf_id], dsT)
993+
tlx.local_store(ds_tiles[ds_buf_id], dsT)
993994
tlx.fence_async_shared()
994-
tlx.barrier_arrive(ds_fulls[tmem_buf_id])
995+
tlx.barrier_arrive(ds_fulls[ds_buf_id])
995996
curr_m += step_m
996997

997998
# epilogue
@@ -1026,7 +1027,7 @@ def _attn_bwd(
10261027
)
10271028

10281029
# mma
1029-
with tlx.async_task(num_warps=1, registers=88):
1030+
with tlx.async_task(num_warps=1, registers=48):
10301031
_, _, start_m, _, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H, N_CTX, BLOCK_M1,
10311032
BLOCK_N1)
10321033

@@ -1038,10 +1039,66 @@ def _attn_bwd(
10381039
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
10391040
curr_m = start_m
10401041
step_m = BLOCK_M1
1041-
for blk_idx in range(num_steps):
1042+
blk_idx = 0
1043+
1044+
# -----------------------------------------------------------
1045+
# Prolog
1046+
#
1047+
# 1. qkT = tl.dot(k, qT)
1048+
# 2. dpT = tl.dot(v, tl.trans(do))
1049+
# 3. dv += tl.dot(ppT, do)
1050+
# -----------------------------------------------------------
1051+
1052+
q_buf_id, q_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_Q)
1053+
do_buf_id, do_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DO)
1054+
tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
1055+
1056+
# Compute qkT = tl.dot(k, qT)
1057+
tlx.barrier_wait(q_fulls[q_buf_id], q_phase)
1058+
tlx.barrier_wait(qk_empties[tmem_buf_id], tmem_phase ^ 1)
1059+
qT = tlx.local_trans(q_tiles[q_buf_id])
1060+
tlx.async_dot(
1061+
k_tiles[0],
1062+
qT,
1063+
qk_tiles[tmem_buf_id],
1064+
use_acc=False,
1065+
mBarriers=[qk_fulls[tmem_buf_id]],
1066+
)
1067+
1068+
# Compute dpT = tl.dot(v, tl.trans(do))
1069+
tlx.barrier_wait(do_fulls[do_buf_id], do_phase)
1070+
# As dP uses the same tmem as dQ, wait for dQ release.
1071+
tlx.barrier_wait(dq_empties[tmem_buf_id], tmem_phase ^ 1)
1072+
doT = tlx.local_trans(do_tiles[do_buf_id])
1073+
tlx.async_dot(
1074+
v_tiles[0],
1075+
doT,
1076+
dp_tiles[tmem_buf_id],
1077+
use_acc=False,
1078+
mBarriers=[dp_fulls[tmem_buf_id]],
1079+
)
1080+
1081+
# Compute dv += tl.dot(ppT, do)
1082+
tlx.barrier_wait(p_fulls[tmem_buf_id], tmem_phase)
1083+
tlx.async_dot(
1084+
p_tiles[tmem_buf_id],
1085+
do_tiles[do_buf_id],
1086+
dv_tiles[tmem_buf_id],
1087+
use_acc=False,
1088+
mBarriers=[do_empties[do_buf_id]],
1089+
)
1090+
1091+
# -----------------------------------------------------------
1092+
# Main loop
1093+
# 1. qkT = tl.dot(k, qT)
1094+
# 2. dq = tl.dot(tl.trans(dsT), k) from previous iteration
1095+
# 3. dk += tl.dot(dsT, tl.trans(qT)) from previous iteration
1096+
# 4. dpT = tl.dot(v, tl.trans(do))
1097+
# 5. dv += tl.dot(ppT, do)
1098+
# -----------------------------------------------------------
1099+
for blk_idx in range(1, num_steps):
10421100
q_buf_id, q_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_Q)
10431101
tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
1044-
10451102
# Compute qkT = tl.dot(k, qT)
10461103
tlx.barrier_wait(q_fulls[q_buf_id], q_phase)
10471104
tlx.barrier_wait(qk_empties[tmem_buf_id], tmem_phase ^ 1)
@@ -1054,11 +1111,38 @@ def _attn_bwd(
10541111
mBarriers=[qk_fulls[tmem_buf_id]],
10551112
)
10561113

1114+
prev_blk_idx = blk_idx - 1
1115+
q_buf_id_prev, _ = _get_bufidx_phase(prev_blk_idx, NUM_BUFFERS_Q)
1116+
tmem_buf_id_prev, tmem_phase_prev = _get_bufidx_phase(prev_blk_idx, NUM_BUFFERS_TMEM)
1117+
ds_buf_id_prev, ds_phase_prev = _get_bufidx_phase(prev_blk_idx, NUM_BUFFERS_DS)
1118+
1119+
# Compute dq = tl.dot(tl.trans(dsT), k) from previous iteration
1120+
tlx.barrier_wait(ds_fulls[tmem_buf_id_prev], ds_phase_prev)
1121+
tlx.barrier_wait(dq_empties[tmem_buf_id_prev], tmem_phase_prev ^ 1)
1122+
dsT_view = tlx.local_trans(ds_tiles[ds_buf_id_prev])
1123+
tlx.async_dot(
1124+
dsT_view,
1125+
k_tiles[0],
1126+
dq_tiles[tmem_buf_id_prev],
1127+
use_acc=False,
1128+
mBarriers=[dq_fulls[tmem_buf_id_prev]],
1129+
)
1130+
1131+
# Compute dk += tl.dot(dsT, tl.trans(qT)) from previous iteration
1132+
tlx.async_dot(
1133+
ds_tiles[ds_buf_id_prev],
1134+
q_tiles[q_buf_id_prev],
1135+
dk_tiles[tmem_buf_id_prev],
1136+
use_acc=prev_blk_idx > 0,
1137+
mBarriers=[q_empties[q_buf_id_prev]],
1138+
)
1139+
1140+
do_buf_id, do_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DO)
10571141
# Compute dpT = tl.dot(v, tl.trans(do))
1058-
tlx.barrier_wait(do_fulls[q_buf_id], q_phase)
1142+
tlx.barrier_wait(do_fulls[do_buf_id], do_phase)
10591143
# As dP uses the same tmem as dQ, wait for dQ release.
10601144
tlx.barrier_wait(dq_empties[tmem_buf_id], tmem_phase ^ 1)
1061-
doT = tlx.local_trans(do_tiles[q_buf_id])
1145+
doT = tlx.local_trans(do_tiles[do_buf_id])
10621146
tlx.async_dot(
10631147
v_tiles[0],
10641148
doT,
@@ -1071,63 +1155,89 @@ def _attn_bwd(
10711155
tlx.barrier_wait(p_fulls[tmem_buf_id], tmem_phase)
10721156
tlx.async_dot(
10731157
p_tiles[tmem_buf_id],
1074-
do_tiles[q_buf_id],
1158+
do_tiles[do_buf_id],
10751159
dv_tiles[tmem_buf_id],
1076-
use_acc=blk_idx > 0,
1077-
mBarriers=[do_empties[q_buf_id]],
1160+
use_acc=True,
1161+
mBarriers=[do_empties[do_buf_id]],
10781162
)
10791163

1080-
# Compute dk += tl.dot(dsT, tl.trans(qT))
1081-
tlx.barrier_wait(ds_fulls[tmem_buf_id], tmem_phase)
1082-
tlx.async_dot(
1083-
ds_tiles[tmem_buf_id],
1084-
q_tiles[q_buf_id],
1085-
dk_tiles[tmem_buf_id],
1086-
use_acc=blk_idx > 0,
1087-
mBarriers=[q_empties[tmem_buf_id]],
1088-
)
1164+
tlx.tcgen05_commit(dv_fulls[kv_buf_id])
10891165

1090-
# Compute dq = tl.dot(tl.trans(dsT), k)
1091-
tlx.barrier_wait(dq_empties[tmem_buf_id], tmem_phase ^ 1)
1092-
dsT_view = tlx.local_trans(ds_tiles[tmem_buf_id])
1093-
tlx.async_dot(
1094-
dsT_view,
1095-
k_tiles[0],
1096-
dq_tiles[tmem_buf_id],
1097-
use_acc=False,
1098-
mBarriers=[dq_fulls[tmem_buf_id]],
1099-
)
1166+
# -----------------------------------------------------------
1167+
# Epilog
1168+
# 4. dk += tl.dot(dsT, tl.trans(qT))
1169+
# 5. dq = tl.dot(tl.trans(dsT), k)
1170+
# -----------------------------------------------------------
1171+
q_buf_id, _ = _get_bufidx_phase(num_steps - 1, NUM_BUFFERS_Q)
1172+
tmem_buf_id, tmem_phase = _get_bufidx_phase(num_steps - 1, NUM_BUFFERS_TMEM)
1173+
ds_buf_id, ds_phase = _get_bufidx_phase(num_steps - 1, NUM_BUFFERS_DS)
1174+
# Compute dk += tl.dot(dsT, tl.trans(qT))
1175+
tlx.barrier_wait(ds_fulls[tmem_buf_id], ds_phase)
1176+
tlx.async_dot(
1177+
ds_tiles[ds_buf_id],
1178+
q_tiles[q_buf_id],
1179+
dk_tiles[tmem_buf_id],
1180+
use_acc=num_steps > 1,
1181+
mBarriers=[q_empties[q_buf_id], dk_fulls[kv_buf_id]],
1182+
)
11001183

1101-
tlx.tcgen05_commit(dv_fulls[kv_buf_id])
1102-
tlx.tcgen05_commit(dk_fulls[kv_buf_id])
1184+
# Compute dq = tl.dot(tl.trans(dsT), k)
1185+
tlx.barrier_wait(dq_empties[tmem_buf_id], tmem_phase ^ 1)
1186+
dsT_view = tlx.local_trans(ds_tiles[ds_buf_id])
1187+
tlx.async_dot(
1188+
dsT_view,
1189+
k_tiles[0],
1190+
dq_tiles[tmem_buf_id],
1191+
use_acc=False,
1192+
mBarriers=[dq_fulls[tmem_buf_id]],
1193+
)
11031194

11041195
# load
11051196
with tlx.async_task(num_warps=1, registers=88):
11061197
_, off_bh, start_m, start_n, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H, N_CTX,
11071198
BLOCK_M1, BLOCK_N1)
1108-
1199+
# Load K
11091200
kv_buf_id, _ = _get_bufidx_phase(0, NUM_BUFFERS_KV)
11101201
tlx.barrier_expect_bytes(k_fulls[kv_buf_id], 2 * BLOCK_N1 * HEAD_DIM) # float16
11111202
tlx.async_descriptor_load(desc_k, k_tiles[kv_buf_id], [(off_bh + start_n).to(tl.int32), 0],
11121203
k_fulls[kv_buf_id])
11131204

1205+
# Load Q
1206+
curr_m = start_m
1207+
step_m = BLOCK_M1
1208+
blk_idx = 0
1209+
q_buf_id, q_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_Q)
1210+
tlx.barrier_wait(q_empties[q_buf_id], q_phase ^ 1)
1211+
tlx.barrier_expect_bytes(q_fulls[q_buf_id], 2 * BLOCK_M1 * HEAD_DIM)
1212+
tlx.async_descriptor_load(desc_q, q_tiles[q_buf_id], [(off_bh + curr_m).to(tl.int32), 0], q_fulls[q_buf_id])
1213+
1214+
# Load V
11141215
tlx.barrier_expect_bytes(v_fulls[kv_buf_id], 2 * BLOCK_N1 * HEAD_DIM) # float16
11151216
tlx.async_descriptor_load(desc_v, v_tiles[kv_buf_id], [(off_bh + start_n).to(tl.int32), 0],
11161217
v_fulls[kv_buf_id])
11171218

1118-
curr_m = start_m
1119-
step_m = BLOCK_M1
1120-
for blk_idx in range(num_steps):
1219+
# Load dO
1220+
do_buf_id, do_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DO)
1221+
tlx.barrier_wait(do_empties[do_buf_id], do_phase ^ 1)
1222+
tlx.barrier_expect_bytes(do_fulls[do_buf_id], 2 * BLOCK_M1 * HEAD_DIM)
1223+
tlx.async_descriptor_load(desc_do, do_tiles[do_buf_id], [(off_bh + curr_m).to(tl.int32), 0],
1224+
do_fulls[do_buf_id])
1225+
curr_m += step_m
1226+
1227+
for blk_idx in range(1, num_steps):
11211228
q_buf_id, q_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_Q)
1229+
do_buf_id, do_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DO)
1230+
# Load Q
11221231
tlx.barrier_wait(q_empties[q_buf_id], q_phase ^ 1)
11231232
tlx.barrier_expect_bytes(q_fulls[q_buf_id], 2 * BLOCK_M1 * HEAD_DIM)
11241233
tlx.async_descriptor_load(desc_q, q_tiles[q_buf_id], [(off_bh + curr_m).to(tl.int32), 0],
11251234
q_fulls[q_buf_id])
11261235

1127-
tlx.barrier_wait(do_empties[q_buf_id], q_phase ^ 1)
1128-
tlx.barrier_expect_bytes(do_fulls[q_buf_id], 2 * BLOCK_M1 * HEAD_DIM)
1129-
tlx.async_descriptor_load(desc_do, do_tiles[q_buf_id], [(off_bh + curr_m).to(tl.int32), 0],
1130-
do_fulls[q_buf_id])
1236+
# Load dO
1237+
tlx.barrier_wait(do_empties[do_buf_id], do_phase ^ 1)
1238+
tlx.barrier_expect_bytes(do_fulls[do_buf_id], 2 * BLOCK_M1 * HEAD_DIM)
1239+
tlx.async_descriptor_load(desc_do, do_tiles[do_buf_id], [(off_bh + curr_m).to(tl.int32), 0],
1240+
do_fulls[do_buf_id])
11311241
curr_m += step_m
11321242

11331243

@@ -1241,7 +1351,7 @@ def grid(meta):
12411351
1, # (or cdiv over M if you need)
12421352
BATCH * N_HEAD) # batch*heads
12431353

1244-
_attn_bwd[grid](
1354+
_attn_bwd_ws[grid](
12451355
desc_q, desc_k, desc_v, ctx.sm_scale, desc_do, desc_dq, desc_dk, desc_dv, #
12461356
M, delta, #
12471357
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #

0 commit comments

Comments
 (0)