Skip to content

Commit 29b4513

Browse files
njriasan and meta-codesync[bot]
authored and committed
[TLX] Provide a basic implementation of FA causal backwards (#713)
Summary: Adds a basic implementation for backwards. It needs significant optimization.
Pull Request resolved: #713
Reviewed By: srivatsan-ramesh, htyu
Differential Revision: D88263408
Pulled By: njriasan
fbshipit-source-id: 104895a1e944bdf45404e92fca7f22f03227068c
1 parent 9bf80fe commit 29b4513

File tree

1 file changed

+176
-49
lines changed

1 file changed

+176
-49
lines changed

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py

Lines changed: 176 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,20 @@ def _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr):
166166
return lo, hi
167167

168168

169+
@triton.jit
def _get_unfused_bwd_loop_bounds(N_CTX, BLOCK_M, STAGE: tl.constexpr):
    """Return the half-open row range [lo, hi) for the backward inner loop.

    Unlike the forward-pass `_get_unfused_loop_bounds`, these bounds do not
    depend on `start_m`: the backward pass walks query rows for a K/V tile
    starting from row 0.

    Args:
        N_CTX: Total sequence length.
        BLOCK_M: Row tile size. Currently unused here; presumably kept for
            signature parity with `_get_unfused_loop_bounds` — TODO confirm.
        STAGE: compile-time stage selector. 1 and 2 are the two halves of
            the causal (STAGE == 3) split; any other value must be 3.

    Returns:
        (lo, hi): loop bounds; callers step through them in BLOCK_M chunks.
    """
    if STAGE == 1:
        # First part of STAGE == 3
        lo, hi = 0, N_CTX
    elif STAGE == 2:
        # Second part of STAGE == 3 in this function. Intentionally an
        # empty range (lo == hi == N_CTX), so callers take zero steps;
        # the split point is not yet computed (see the STAGE & 2 TODO in
        # the caller).
        lo, hi = N_CTX, N_CTX
    else:
        # Only stages 1-3 are valid; fail at compile time otherwise.
        tl.static_assert(STAGE == 3)
        lo, hi = 0, N_CTX
    return lo, hi
181+
182+
169183
@triton.jit
170184
def _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr):
171185
if STAGE == 1:
@@ -867,7 +881,7 @@ def _attn_fwd_ws(sm_scale, M, #
867881
@triton.jit
868882
def _attn_bwd_preprocess(O, DO, #
869883
Delta, #
870-
Z, H, N_CTX, #
884+
N_CTX, #
871885
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, #
872886
):
873887
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -1107,6 +1121,77 @@ def _bwd_host_descriptor_pre_hook_tlx(nargs):
11071121
]
11081122

11091123

1124+
# TODO: Unused. Fix layout issue inside TLX.
@triton.jit
def _bwd_compute_inner_loop(
    start_n,
    qk_fulls,
    qk_tiles,
    qk_empties,
    p_tiles,
    p_fulls,
    dp_fulls,
    dp_tiles,
    ds_tiles,
    ds_fulls,
    M,
    D,
    curr_m,
    blk_idx,
    step_m,
    do_out_dtype,
    q_out_dtype,
    N_CTX,
    NUM_BUFFERS_TMEM: tl.constexpr,
    NUM_BUFFERS_DS: tl.constexpr,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    STAGE: tl.constexpr,
):
    """Compute worker inner loop of the warp-specialized FA backward pass.

    For one K/V tile (rows `start_n .. start_n + BLOCK_N1`), iterates over
    query-row blocks of size BLOCK_M1, consuming qkT and dpT tiles produced
    by the MMA task and producing pT and dsT tiles for it, with TLX barriers
    sequencing the handoff.

    NOTE(review): this mirrors the loop inlined in `_attn_bwd_ws`, but masks
    on STAGE == 1 where the inlined copy masks on STAGE == 3 and is invoked
    with `STAGE=4 - STAGE` — presumably the same stage convention after the
    4 - STAGE mapping; confirm when wiring this helper in.

    Args:
        start_n: first K/V row handled by this tile.
        qk_fulls/qk_tiles/qk_empties: barriers + buffers for the qkT product.
        p_tiles/p_fulls: buffers + barriers for the computed pT tile.
        dp_fulls/dp_tiles: barriers + buffers for the dpT product.
        ds_tiles/ds_fulls: buffers + barriers for the computed dsT tile.
        M: per-row log-sum-exp statistics from the forward pass (base-2).
        D: per-row delta values (pre-divided by ds_scale).
        curr_m: starting query row offset.
        blk_idx: starting buffer index (drives buffer-id/phase rotation).
        step_m: row advance per iteration (BLOCK_M1).
        do_out_dtype/q_out_dtype: element dtypes for the pT and dsT stores.
        N_CTX: total sequence length.
        NUM_BUFFERS_TMEM/NUM_BUFFERS_DS: pipeline depths for tmem/ds buffers.
        BLOCK_M1/BLOCK_N1: tile sizes.
        STAGE: compile-time stage selector (see _get_unfused_bwd_loop_bounds).

    Returns:
        (curr_m, blk_idx): advanced row offset and buffer index, so a caller
        can chain a second stage's loop after this one.
    """
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    lo, hi = _get_unfused_bwd_loop_bounds(N_CTX, BLOCK_M1, STAGE)
    num_steps = (hi - lo) // BLOCK_M1
    for _ in range(num_steps):
        # Rotate through the pipelined buffers; phase flips each wrap-around.
        tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
        ds_buf_id, _ = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DS)

        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        # Per-row softmax statistics (base-2 log-sum-exp) from the forward.
        m = tl.load(M + offs_m)

        # wait for qkT = tl.dot(k, qT)
        tlx.barrier_wait(tlx.local_view(qk_fulls, tmem_buf_id), tmem_phase)
        qkT = tlx.local_load(tlx.local_view(qk_tiles, tmem_buf_id))
        tlx.barrier_arrive(tlx.local_view(qk_empties, tmem_buf_id))

        # Recompute pT = softmax numerator; exp2 matches the forward's base-2
        # statistics in M.
        pT = tl.math.exp2(qkT - m[None, :])
        if STAGE == 1:
            # Causal masking: zero out entries above the diagonal.
            mask = offs_m[None, :] >= offs_n[:, None]
            pT = tl.where(mask, pT, 0.0)

        # ppT *= qk_scale
        ppT = pT
        ppT = ppT.to(do_out_dtype)
        tlx.local_store(tlx.local_view(p_tiles, tmem_buf_id), ppT)
        tlx.barrier_arrive(tlx.local_view(p_fulls, tmem_buf_id))

        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)

        # Wait for dpT = tl.dot(v, tl.trans(do))
        tlx.barrier_wait(tlx.local_view(dp_fulls, tmem_buf_id), tmem_phase)
        dpT = tlx.local_load(tlx.local_view(dp_tiles, tmem_buf_id))
        # No need to release dP, as dP uses the same tmem as dQ
        # in the same iteration. Release dQ instead later.
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(q_out_dtype)
        tlx.local_store(tlx.local_view(ds_tiles, ds_buf_id), dsT)
        # Make the shared-memory store visible before signaling consumers.
        tlx.fence_async_shared()
        tlx.barrier_arrive(tlx.local_view(ds_fulls, ds_buf_id))
        curr_m += step_m
        blk_idx += 1
    return curr_m, blk_idx
1193+
1194+
11101195
@triton.autotune(configs=configs_bwd_tlx, key=["N_CTX", "HEAD_DIM"])
11111196
@triton.jit
11121197
def _attn_bwd_ws(
@@ -1139,6 +1224,7 @@ def _attn_bwd_ws(
11391224
NUM_BUFFERS_DS: tl.constexpr,
11401225
NUM_BUFFERS_TMEM: tl.constexpr,
11411226
EPILOGUE_SUBTILE: tl.constexpr,
1227+
STAGE: tl.constexpr,
11421228
):
11431229
# allocate smem buffers
11441230
k_tiles = tlx.local_alloc((BLOCK_N1, HEAD_DIM), tlx.dtype_of(desc_k), NUM_BUFFERS_KV)
@@ -1200,8 +1286,15 @@ def _attn_bwd_ws(
12001286
with tlx.async_tasks():
12011287
# reduction
12021288
with tlx.async_task("default"):
1203-
off_chz, off_bh, start_m, _, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H, N_CTX,
1204-
BLOCK_M1, BLOCK_N1)
1289+
off_chz, off_bh, start_m, _, num_steps = bwd_caculate_offsets(
1290+
stride_z,
1291+
stride_h,
1292+
stride_tok,
1293+
H,
1294+
N_CTX,
1295+
BLOCK_M1,
1296+
BLOCK_N1,
1297+
)
12051298
curr_m = start_m
12061299
step_m = BLOCK_M1
12071300
for blk_idx in range(num_steps):
@@ -1227,49 +1320,67 @@ def _attn_bwd_ws(
12271320

12281321
# compute
12291322
with tlx.async_task(num_warps=8, registers=192, replicate=1):
1230-
off_chz, off_bh, start_m, start_n, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H,
1231-
N_CTX, BLOCK_M1, BLOCK_N1)
1232-
1323+
off_chz, off_bh, _, start_n, _ = bwd_caculate_offsets(
1324+
stride_z,
1325+
stride_h,
1326+
stride_tok,
1327+
H,
1328+
N_CTX,
1329+
BLOCK_M1,
1330+
BLOCK_N1,
1331+
)
12331332
# offset pointers for batch/head
12341333
M += off_chz
12351334
D += off_chz
1236-
curr_m = start_m
1335+
curr_m = 0
12371336
step_m = BLOCK_M1
1238-
for blk_idx in range(num_steps):
1239-
tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
1240-
ds_buf_id, _ = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DS)
1241-
1242-
offs_m = curr_m + tl.arange(0, BLOCK_M1)
1243-
m = tl.load(M + offs_m)
1244-
1245-
# wait for qkT = tl.dot(k, qT)
1246-
tlx.barrier_wait(qk_fulls[tmem_buf_id], tmem_phase)
1247-
qkT = tlx.local_load(qk_tiles[tmem_buf_id])
1248-
tlx.barrier_arrive(qk_empties[tmem_buf_id])
1249-
1250-
pT = tl.math.exp2(qkT - m[None, :])
1251-
1252-
# ppT *= qk_scale
1253-
ppT = pT
1254-
ppT = ppT.to(tlx.dtype_of(desc_do))
1255-
tlx.local_store(p_tiles[tmem_buf_id], ppT)
1256-
tlx.barrier_arrive(p_fulls[tmem_buf_id])
1257-
1258-
# D (= delta) is pre-divided by ds_scale.
1259-
Di = tl.load(D + offs_m)
1260-
1261-
# Wait for dpT = tl.dot(v, tl.trans(do))
1262-
tlx.barrier_wait(dp_fulls[tmem_buf_id], tmem_phase)
1263-
dpT = tlx.local_load(dp_tiles[tmem_buf_id])
1264-
# No need to release dP, as dP uses the same tmem as dQ
1265-
# in the same iteration. Release dQ instead later.
1266-
dsT = pT * (dpT - Di[None, :])
1267-
dsT = dsT.to(tlx.dtype_of(desc_q))
1268-
tlx.local_store(ds_tiles[ds_buf_id], dsT)
1269-
tlx.fence_async_shared()
1270-
tlx.barrier_arrive(ds_fulls[ds_buf_id])
1271-
curr_m += step_m
1272-
1337+
do_out_dtype = tlx.dtype_of(desc_do)
1338+
q_out_dtype = tlx.dtype_of(desc_q)
1339+
if STAGE & 1:
1340+
offs_n = start_n + tl.arange(0, BLOCK_N1)
1341+
lo, hi = _get_unfused_bwd_loop_bounds(N_CTX, BLOCK_M1, STAGE=4 - STAGE)
1342+
num_steps = (hi - lo) // BLOCK_M1
1343+
blk_idx = 0
1344+
for _ in range(num_steps):
1345+
tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
1346+
ds_buf_id, _ = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DS)
1347+
1348+
offs_m = curr_m + tl.arange(0, BLOCK_M1)
1349+
m = tl.load(M + offs_m)
1350+
1351+
# wait for qkT = tl.dot(k, qT)
1352+
tlx.barrier_wait(tlx.local_view(qk_fulls, tmem_buf_id), tmem_phase)
1353+
qkT = tlx.local_load(tlx.local_view(qk_tiles, tmem_buf_id))
1354+
tlx.barrier_arrive(tlx.local_view(qk_empties, tmem_buf_id))
1355+
1356+
pT = tl.math.exp2(qkT - m[None, :])
1357+
if STAGE == 3:
1358+
mask = offs_m[None, :] >= offs_n[:, None]
1359+
pT = tl.where(mask, pT, 0.0)
1360+
1361+
# ppT *= qk_scale
1362+
ppT = pT
1363+
ppT = ppT.to(do_out_dtype)
1364+
tlx.local_store(tlx.local_view(p_tiles, tmem_buf_id), ppT)
1365+
tlx.barrier_arrive(tlx.local_view(p_fulls, tmem_buf_id))
1366+
1367+
# D (= delta) is pre-divided by ds_scale.
1368+
Di = tl.load(D + offs_m)
1369+
1370+
# Wait for dpT = tl.dot(v, tl.trans(do))
1371+
tlx.barrier_wait(tlx.local_view(dp_fulls, tmem_buf_id), tmem_phase)
1372+
dpT = tlx.local_load(tlx.local_view(dp_tiles, tmem_buf_id))
1373+
# No need to release dP, as dP uses the same tmem as dQ
1374+
# in the same iteration. Release dQ instead later.
1375+
dsT = pT * (dpT - Di[None, :])
1376+
dsT = dsT.to(q_out_dtype)
1377+
tlx.local_store(tlx.local_view(ds_tiles, ds_buf_id), dsT)
1378+
tlx.fence_async_shared()
1379+
tlx.barrier_arrive(tlx.local_view(ds_fulls, ds_buf_id))
1380+
curr_m += step_m
1381+
blk_idx += 1
1382+
# TODO: Add the STAGE & 2 handling when we can determine bounds to divide
1383+
# the work across two loops, based on optimizing out the mask.
12731384
# epilogue
12741385
kv_buf_id, kv_phase = _get_bufidx_phase(0, NUM_BUFFERS_KV)
12751386

@@ -1303,8 +1414,15 @@ def _attn_bwd_ws(
13031414

13041415
# mma
13051416
with tlx.async_task(num_warps=1, registers=48):
1306-
_, _, start_m, _, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H, N_CTX, BLOCK_M1,
1307-
BLOCK_N1)
1417+
_, _, start_m, _, num_steps = bwd_caculate_offsets(
1418+
stride_z,
1419+
stride_h,
1420+
stride_tok,
1421+
H,
1422+
N_CTX,
1423+
BLOCK_M1,
1424+
BLOCK_N1,
1425+
)
13081426

13091427
kv_buf_id, kv_phase = _get_bufidx_phase(0, NUM_BUFFERS_KV)
13101428
tlx.barrier_wait(k_fulls[kv_buf_id], kv_phase)
@@ -1469,8 +1587,15 @@ def _attn_bwd_ws(
14691587

14701588
# load
14711589
with tlx.async_task(num_warps=1, registers=88):
1472-
_, off_bh, start_m, start_n, num_steps = bwd_caculate_offsets(stride_z, stride_h, stride_tok, H, N_CTX,
1473-
BLOCK_M1, BLOCK_N1)
1590+
_, off_bh, start_m, start_n, num_steps = bwd_caculate_offsets(
1591+
stride_z,
1592+
stride_h,
1593+
stride_tok,
1594+
H,
1595+
N_CTX,
1596+
BLOCK_M1,
1597+
BLOCK_N1,
1598+
)
14741599
# Load K
14751600
kv_buf_id, _ = _get_bufidx_phase(0, NUM_BUFFERS_KV)
14761601
tlx.barrier_expect_bytes(k_fulls[kv_buf_id], 2 * BLOCK_N1 * HEAD_DIM) # float16
@@ -1632,6 +1757,7 @@ def grid(META):
16321757
ctx.save_for_backward(q, k, v, o, M)
16331758
ctx.sm_scale = sm_scale
16341759
ctx.HEAD_DIM = HEAD_DIM_K
1760+
ctx.causal = causal
16351761
return o
16361762

16371763
@staticmethod
@@ -1655,7 +1781,7 @@ def backward(ctx, do):
16551781
_attn_bwd_preprocess[pre_grid](
16561782
o, do, #
16571783
delta, #
1658-
BATCH, N_HEAD, N_CTX, #
1784+
N_CTX, #
16591785
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM, #
16601786
)
16611787

@@ -1716,13 +1842,16 @@ def grid(meta):
17161842
BATCH * N_HEAD,
17171843
) # batch*heads
17181844

1845+
stage = 3 if ctx.causal else 1
1846+
17191847
_attn_bwd_ws[grid](
17201848
desc_q, desc_k, desc_v, ctx.sm_scale, desc_do, desc_dq, desc_dk, desc_dv, #
17211849
M, delta, #
17221850
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
17231851
N_HEAD, N_CTX, #
17241852
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
17251853
HEAD_DIM=ctx.HEAD_DIM, #
1854+
STAGE=stage, #
17261855
)
17271856

17281857
return dq, dk, dv, None, None, None, None
@@ -1750,8 +1879,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16):
17501879
sm_scale = 0.5
17511880
# reference implementation
17521881
ref_dtype = dtype
1753-
if mode == "bwd" and causal:
1754-
pytest.skip("Causal not supported for bwd yet")
17551882
if mode == "fwd" and "fp8" in provider:
17561883
ref_dtype = torch.float32
17571884
q = q.to(ref_dtype)

0 commit comments

Comments
 (0)