@@ -166,6 +166,20 @@ def _get_unfused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr):
166166 return lo , hi
167167
168168
@triton.jit
def _get_unfused_bwd_loop_bounds(N_CTX, BLOCK_M, STAGE: tl.constexpr):
    """Return the ``[lo, hi)`` query-row range for one unfused backward pass.

    STAGE is a compile-time constexpr, so the untaken branches are elided
    at specialization time:

    * ``STAGE == 1``: first part of a split ``STAGE == 3`` pass — covers
      the full ``[0, N_CTX)`` range.
    * ``STAGE == 2``: second part of a split ``STAGE == 3`` pass — currently
      an empty range (``lo == hi == N_CTX``), i.e. this half does no work
      yet (see the caller's TODO about dividing the work across two loops).
    * ``STAGE == 3``: the whole, non-split pass — full ``[0, N_CTX)`` range.

    NOTE(review): ``BLOCK_M`` is accepted but unused here; presumably kept
    for signature parity with ``_get_unfused_loop_bounds`` — confirm.
    """
    if STAGE == 1:
        # First part of STAGE == 3
        lo, hi = 0, N_CTX
    elif STAGE == 2:
        # Second part of STAGE == 3 in this function
        lo, hi = N_CTX, N_CTX
    else:
        # Any other value is a compile-time error.
        tl.static_assert(STAGE == 3)
        lo, hi = 0, N_CTX
    return lo, hi
181+
182+
169183@triton .jit
170184def _get_fused_loop_bounds (start_m , N_CTX , BLOCK_M , STAGE : tl .constexpr ):
171185 if STAGE == 1 :
@@ -867,7 +881,7 @@ def _attn_fwd_ws(sm_scale, M, #
867881@triton .jit
868882def _attn_bwd_preprocess (O , DO , #
869883 Delta , #
870- Z , H , N_CTX , #
884+ N_CTX , #
871885 BLOCK_M : tl .constexpr , HEAD_DIM : tl .constexpr , #
872886 ):
873887 off_m = tl .program_id (0 ) * BLOCK_M + tl .arange (0 , BLOCK_M )
@@ -1107,6 +1121,77 @@ def _bwd_host_descriptor_pre_hook_tlx(nargs):
11071121]
11081122
11091123
# TODO: Unused. Fix layout issue inside TLX.
@triton.jit
def _bwd_compute_inner_loop(
    start_n,
    qk_fulls,
    qk_tiles,
    qk_empties,
    p_tiles,
    p_fulls,
    dp_fulls,
    dp_tiles,
    ds_tiles,
    ds_fulls,
    M,
    D,
    curr_m,
    blk_idx,
    step_m,
    do_out_dtype,
    q_out_dtype,
    N_CTX,
    NUM_BUFFERS_TMEM: tl.constexpr,
    NUM_BUFFERS_DS: tl.constexpr,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    STAGE: tl.constexpr,
):
    """Inner loop of the compute task in the warp-specialized attention bwd.

    For each BLOCK_M1-sized step over the query rows selected by STAGE
    (via ``_get_unfused_bwd_loop_bounds``), this loop:

    1. waits on ``qk_fulls`` for a qk^T tile produced elsewhere, converts it
       to p^T with ``exp2(qkT - m)``, applying a causal-style mask when
       ``STAGE == 1`` (entries with key index > query index are zeroed),
    2. stores p^T into ``p_tiles`` (cast to ``do_out_dtype``) and signals
       ``p_fulls`` for the consumer,
    3. waits on ``dp_fulls`` for the dp^T tile, forms
       ``dsT = pT * (dpT - Di)``, stores it into ``ds_tiles`` (cast to
       ``q_out_dtype``), fences, and signals ``ds_fulls``.

    Returns the advanced ``(curr_m, blk_idx)`` so a caller can resume a
    subsequent phase where this one stopped.

    NOTE(review): assumes ``M`` holds per-row log2-domain softmax statistics
    and ``D`` the preprocessed delta (pre-divided by ds_scale, per the
    inline comment) — confirm against the host-side setup.
    """
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    lo, hi = _get_unfused_bwd_loop_bounds(N_CTX, BLOCK_M1, STAGE)
    num_steps = (hi - lo) // BLOCK_M1
    for _ in range(num_steps):
        # Map the running block index onto the ring buffers: buffer slot
        # plus the phase bit used to disambiguate barrier reuse.
        tmem_buf_id, tmem_phase = _get_bufidx_phase(blk_idx, NUM_BUFFERS_TMEM)
        ds_buf_id, _ = _get_bufidx_phase(blk_idx, NUM_BUFFERS_DS)

        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)

        # wait for qkT = tl.dot(k, qT)
        tlx.barrier_wait(tlx.local_view(qk_fulls, tmem_buf_id), tmem_phase)
        qkT = tlx.local_load(tlx.local_view(qk_tiles, tmem_buf_id))
        # The qk buffer can be refilled as soon as we have loaded it.
        tlx.barrier_arrive(tlx.local_view(qk_empties, tmem_buf_id))

        pT = tl.math.exp2(qkT - m[None, :])
        if STAGE == 1:
            # Masked half: keep only entries where the query row index is
            # >= the key column index (causal-style lower-triangular mask).
            mask = offs_m[None, :] >= offs_n[:, None]
            pT = tl.where(mask, pT, 0.0)

        # ppT *= qk_scale
        ppT = pT
        ppT = ppT.to(do_out_dtype)
        tlx.local_store(tlx.local_view(p_tiles, tmem_buf_id), ppT)
        tlx.barrier_arrive(tlx.local_view(p_fulls, tmem_buf_id))

        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)

        # Wait for dpT = tl.dot(v, tl.trans(do))
        tlx.barrier_wait(tlx.local_view(dp_fulls, tmem_buf_id), tmem_phase)
        dpT = tlx.local_load(tlx.local_view(dp_tiles, tmem_buf_id))
        # No need to release dP, as dP uses the same tmem as dQ
        # in the same iteration. Release dQ instead later.
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(q_out_dtype)
        tlx.local_store(tlx.local_view(ds_tiles, ds_buf_id), dsT)
        # Make the shared-memory store visible before signalling consumers.
        tlx.fence_async_shared()
        tlx.barrier_arrive(tlx.local_view(ds_fulls, ds_buf_id))
        curr_m += step_m
        blk_idx += 1
    return curr_m, blk_idx
1194+
11101195@triton .autotune (configs = configs_bwd_tlx , key = ["N_CTX" , "HEAD_DIM" ])
11111196@triton .jit
11121197def _attn_bwd_ws (
@@ -1139,6 +1224,7 @@ def _attn_bwd_ws(
11391224 NUM_BUFFERS_DS : tl .constexpr ,
11401225 NUM_BUFFERS_TMEM : tl .constexpr ,
11411226 EPILOGUE_SUBTILE : tl .constexpr ,
1227+ STAGE : tl .constexpr ,
11421228):
11431229 # allocate smem buffers
11441230 k_tiles = tlx .local_alloc ((BLOCK_N1 , HEAD_DIM ), tlx .dtype_of (desc_k ), NUM_BUFFERS_KV )
@@ -1200,8 +1286,15 @@ def _attn_bwd_ws(
12001286 with tlx .async_tasks ():
12011287 # reduction
12021288 with tlx .async_task ("default" ):
1203- off_chz , off_bh , start_m , _ , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H , N_CTX ,
1204- BLOCK_M1 , BLOCK_N1 )
1289+ off_chz , off_bh , start_m , _ , num_steps = bwd_caculate_offsets (
1290+ stride_z ,
1291+ stride_h ,
1292+ stride_tok ,
1293+ H ,
1294+ N_CTX ,
1295+ BLOCK_M1 ,
1296+ BLOCK_N1 ,
1297+ )
12051298 curr_m = start_m
12061299 step_m = BLOCK_M1
12071300 for blk_idx in range (num_steps ):
@@ -1227,49 +1320,67 @@ def _attn_bwd_ws(
12271320
12281321 # compute
12291322 with tlx .async_task (num_warps = 8 , registers = 192 , replicate = 1 ):
1230- off_chz , off_bh , start_m , start_n , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H ,
1231- N_CTX , BLOCK_M1 , BLOCK_N1 )
1232-
1323+ off_chz , off_bh , _ , start_n , _ = bwd_caculate_offsets (
1324+ stride_z ,
1325+ stride_h ,
1326+ stride_tok ,
1327+ H ,
1328+ N_CTX ,
1329+ BLOCK_M1 ,
1330+ BLOCK_N1 ,
1331+ )
12331332 # offset pointers for batch/head
12341333 M += off_chz
12351334 D += off_chz
1236- curr_m = start_m
1335+ curr_m = 0
12371336 step_m = BLOCK_M1
1238- for blk_idx in range (num_steps ):
1239- tmem_buf_id , tmem_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_TMEM )
1240- ds_buf_id , _ = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DS )
1241-
1242- offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
1243- m = tl .load (M + offs_m )
1244-
1245- # wait for qkT = tl.dot(k, qT)
1246- tlx .barrier_wait (qk_fulls [tmem_buf_id ], tmem_phase )
1247- qkT = tlx .local_load (qk_tiles [tmem_buf_id ])
1248- tlx .barrier_arrive (qk_empties [tmem_buf_id ])
1249-
1250- pT = tl .math .exp2 (qkT - m [None , :])
1251-
1252- # ppT *= qk_scale
1253- ppT = pT
1254- ppT = ppT .to (tlx .dtype_of (desc_do ))
1255- tlx .local_store (p_tiles [tmem_buf_id ], ppT )
1256- tlx .barrier_arrive (p_fulls [tmem_buf_id ])
1257-
1258- # D (= delta) is pre-divided by ds_scale.
1259- Di = tl .load (D + offs_m )
1260-
1261- # Wait for dpT = tl.dot(v, tl.trans(do))
1262- tlx .barrier_wait (dp_fulls [tmem_buf_id ], tmem_phase )
1263- dpT = tlx .local_load (dp_tiles [tmem_buf_id ])
1264- # No need to release dP, as dP uses the same tmem as dQ
1265- # in the same iteration. Release dQ instead later.
1266- dsT = pT * (dpT - Di [None , :])
1267- dsT = dsT .to (tlx .dtype_of (desc_q ))
1268- tlx .local_store (ds_tiles [ds_buf_id ], dsT )
1269- tlx .fence_async_shared ()
1270- tlx .barrier_arrive (ds_fulls [ds_buf_id ])
1271- curr_m += step_m
1272-
1337+ do_out_dtype = tlx .dtype_of (desc_do )
1338+ q_out_dtype = tlx .dtype_of (desc_q )
1339+ if STAGE & 1 :
1340+ offs_n = start_n + tl .arange (0 , BLOCK_N1 )
1341+ lo , hi = _get_unfused_bwd_loop_bounds (N_CTX , BLOCK_M1 , STAGE = 4 - STAGE )
1342+ num_steps = (hi - lo ) // BLOCK_M1
1343+ blk_idx = 0
1344+ for _ in range (num_steps ):
1345+ tmem_buf_id , tmem_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_TMEM )
1346+ ds_buf_id , _ = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DS )
1347+
1348+ offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
1349+ m = tl .load (M + offs_m )
1350+
1351+ # wait for qkT = tl.dot(k, qT)
1352+ tlx .barrier_wait (tlx .local_view (qk_fulls , tmem_buf_id ), tmem_phase )
1353+ qkT = tlx .local_load (tlx .local_view (qk_tiles , tmem_buf_id ))
1354+ tlx .barrier_arrive (tlx .local_view (qk_empties , tmem_buf_id ))
1355+
1356+ pT = tl .math .exp2 (qkT - m [None , :])
1357+ if STAGE == 3 :
1358+ mask = offs_m [None , :] >= offs_n [:, None ]
1359+ pT = tl .where (mask , pT , 0.0 )
1360+
1361+ # ppT *= qk_scale
1362+ ppT = pT
1363+ ppT = ppT .to (do_out_dtype )
1364+ tlx .local_store (tlx .local_view (p_tiles , tmem_buf_id ), ppT )
1365+ tlx .barrier_arrive (tlx .local_view (p_fulls , tmem_buf_id ))
1366+
1367+ # D (= delta) is pre-divided by ds_scale.
1368+ Di = tl .load (D + offs_m )
1369+
1370+ # Wait for dpT = tl.dot(v, tl.trans(do))
1371+ tlx .barrier_wait (tlx .local_view (dp_fulls , tmem_buf_id ), tmem_phase )
1372+ dpT = tlx .local_load (tlx .local_view (dp_tiles , tmem_buf_id ))
1373+ # No need to release dP, as dP uses the same tmem as dQ
1374+ # in the same iteration. Release dQ instead later.
1375+ dsT = pT * (dpT - Di [None , :])
1376+ dsT = dsT .to (q_out_dtype )
1377+ tlx .local_store (tlx .local_view (ds_tiles , ds_buf_id ), dsT )
1378+ tlx .fence_async_shared ()
1379+ tlx .barrier_arrive (tlx .local_view (ds_fulls , ds_buf_id ))
1380+ curr_m += step_m
1381+ blk_idx += 1
1382+ # TODO: Add the STAGE & 2 handling when we can determine bounds to divide
1383+ # the work across two loops, based on optimizing out the mask.
12731384 # epilogue
12741385 kv_buf_id , kv_phase = _get_bufidx_phase (0 , NUM_BUFFERS_KV )
12751386
@@ -1303,8 +1414,15 @@ def _attn_bwd_ws(
13031414
13041415 # mma
13051416 with tlx .async_task (num_warps = 1 , registers = 48 ):
1306- _ , _ , start_m , _ , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H , N_CTX , BLOCK_M1 ,
1307- BLOCK_N1 )
1417+ _ , _ , start_m , _ , num_steps = bwd_caculate_offsets (
1418+ stride_z ,
1419+ stride_h ,
1420+ stride_tok ,
1421+ H ,
1422+ N_CTX ,
1423+ BLOCK_M1 ,
1424+ BLOCK_N1 ,
1425+ )
13081426
13091427 kv_buf_id , kv_phase = _get_bufidx_phase (0 , NUM_BUFFERS_KV )
13101428 tlx .barrier_wait (k_fulls [kv_buf_id ], kv_phase )
@@ -1469,8 +1587,15 @@ def _attn_bwd_ws(
14691587
14701588 # load
14711589 with tlx .async_task (num_warps = 1 , registers = 88 ):
1472- _ , off_bh , start_m , start_n , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H , N_CTX ,
1473- BLOCK_M1 , BLOCK_N1 )
1590+ _ , off_bh , start_m , start_n , num_steps = bwd_caculate_offsets (
1591+ stride_z ,
1592+ stride_h ,
1593+ stride_tok ,
1594+ H ,
1595+ N_CTX ,
1596+ BLOCK_M1 ,
1597+ BLOCK_N1 ,
1598+ )
14741599 # Load K
14751600 kv_buf_id , _ = _get_bufidx_phase (0 , NUM_BUFFERS_KV )
14761601 tlx .barrier_expect_bytes (k_fulls [kv_buf_id ], 2 * BLOCK_N1 * HEAD_DIM ) # float16
@@ -1632,6 +1757,7 @@ def grid(META):
16321757 ctx .save_for_backward (q , k , v , o , M )
16331758 ctx .sm_scale = sm_scale
16341759 ctx .HEAD_DIM = HEAD_DIM_K
1760+ ctx .causal = causal
16351761 return o
16361762
16371763 @staticmethod
@@ -1655,7 +1781,7 @@ def backward(ctx, do):
16551781 _attn_bwd_preprocess [pre_grid ](
16561782 o , do , #
16571783 delta , #
1658- BATCH , N_HEAD , N_CTX , #
1784+ N_CTX , #
16591785 BLOCK_M = PRE_BLOCK , HEAD_DIM = ctx .HEAD_DIM , #
16601786 )
16611787
@@ -1716,13 +1842,16 @@ def grid(meta):
17161842 BATCH * N_HEAD ,
17171843 ) # batch*heads
17181844
1845+ stage = 3 if ctx .causal else 1
1846+
17191847 _attn_bwd_ws [grid ](
17201848 desc_q , desc_k , desc_v , ctx .sm_scale , desc_do , desc_dq , desc_dk , desc_dv , #
17211849 M , delta , #
17221850 q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
17231851 N_HEAD , N_CTX , #
17241852 BLK_SLICE_FACTOR = BLK_SLICE_FACTOR , #
17251853 HEAD_DIM = ctx .HEAD_DIM , #
1854+ STAGE = stage , #
17261855 )
17271856
17281857 return dq , dk , dv , None , None , None , None
@@ -1750,8 +1879,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16):
17501879 sm_scale = 0.5
17511880 # reference implementation
17521881 ref_dtype = dtype
1753- if mode == "bwd" and causal :
1754- pytest .skip ("Causal not supported for bwd yet" )
17551882 if mode == "fwd" and "fp8" in provider :
17561883 ref_dtype = torch .float32
17571884 q = q .to (ref_dtype )
0 commit comments