@@ -824,11 +824,11 @@ def _bwd_host_descriptor_pre_hook_tlx(nargs):
824824 "BLOCK_M2" : 128 ,
825825 "BLOCK_N2" : 128 ,
826826 "NUM_BUFFERS_KV" : 1 ,
827- "NUM_BUFFERS_Q" : 1 ,
827+ "NUM_BUFFERS_Q" : 2 ,
828828 "NUM_BUFFERS_DO" : 1 ,
829829 "NUM_BUFFERS_DS" : 1 ,
830830 "NUM_BUFFERS_TMEM" : 1 ,
831- "EPILOGUE_SUBTILE" : 2 ,
831+ "EPILOGUE_SUBTILE" : 4 ,
832832 },
833833 num_warps = 4 ,
834834 num_stages = 1 ,
@@ -839,7 +839,7 @@ def _bwd_host_descriptor_pre_hook_tlx(nargs):
839839
840840@triton .autotune (configs = configs_bwd_tlx , key = ["N_CTX" , "HEAD_DIM" ])
841841@triton .jit
842- def _attn_bwd (
842+ def _attn_bwd_ws (
843843 desc_q ,
844844 desc_k ,
845845 desc_v ,
@@ -874,18 +874,18 @@ def _attn_bwd(
874874 k_tiles = tlx .local_alloc ((BLOCK_N1 , HEAD_DIM ), tlx .dtype_of (desc_k ), NUM_BUFFERS_KV )
875875 v_tiles = tlx .local_alloc ((BLOCK_N1 , HEAD_DIM ), tlx .dtype_of (desc_v ), NUM_BUFFERS_KV )
876876 q_tiles = tlx .local_alloc ((BLOCK_M1 , HEAD_DIM ), tlx .dtype_of (desc_q ), NUM_BUFFERS_Q )
877- do_tiles = tlx .local_alloc ((BLOCK_M1 , HEAD_DIM ), tlx .dtype_of (desc_do ), NUM_BUFFERS_Q )
877+ do_tiles = tlx .local_alloc ((BLOCK_M1 , HEAD_DIM ), tlx .dtype_of (desc_do ), NUM_BUFFERS_DO )
878878
879879 # Use SMEM for dsT
880- ds_tiles = tlx .local_alloc ((BLOCK_N1 , BLOCK_M1 ), tlx .dtype_of (desc_q ), NUM_BUFFERS_TMEM )
880+ ds_tiles = tlx .local_alloc ((BLOCK_N1 , BLOCK_M1 ), tlx .dtype_of (desc_q ), NUM_BUFFERS_DS )
881881
882882 # allocate barriers for smem buffers
883883 k_fulls = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_KV )
884884 v_fulls = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_KV )
885885 q_fulls = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_Q )
886886 q_empties = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_Q )
887- do_fulls = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_Q )
888- do_empties = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_Q )
887+ do_fulls = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_DO )
888+ do_empties = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_DO )
889889 ds_fulls = tlx .alloc_barriers (num_barriers = NUM_BUFFERS_TMEM )
890890
891891 # allocate tmem buffers
@@ -951,7 +951,7 @@ def _attn_bwd(
951951 curr_m += step_m
952952
953953 # compute
954- with tlx .async_task (num_warps = 4 , registers = 136 , replicate = 1 ):
954+ with tlx .async_task (num_warps = 8 , registers = 192 , replicate = 1 ):
955955 off_chz , off_bh , start_m , start_n , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H ,
956956 N_CTX , BLOCK_M1 , BLOCK_N1 )
957957
@@ -962,6 +962,7 @@ def _attn_bwd(
962962 step_m = BLOCK_M1
963963 for blk_idx in range (num_steps ):
964964 tmem_buf_id , tmem_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_TMEM )
965+ ds_buf_id , _ = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DS )
965966
966967 offs_m = curr_m + tl .arange (0 , BLOCK_M1 )
967968 m = tl .load (M + offs_m )
@@ -989,9 +990,9 @@ def _attn_bwd(
989990 # in the same iteration. Release dQ instead later.
990991 dsT = pT * (dpT - Di [None , :])
991992 dsT = dsT .to (tlx .dtype_of (desc_q ))
992- tlx .local_store (ds_tiles [tmem_buf_id ], dsT )
993+ tlx .local_store (ds_tiles [ds_buf_id ], dsT )
993994 tlx .fence_async_shared ()
994- tlx .barrier_arrive (ds_fulls [tmem_buf_id ])
995+ tlx .barrier_arrive (ds_fulls [ds_buf_id ])
995996 curr_m += step_m
996997
997998 # epilogue
@@ -1026,7 +1027,7 @@ def _attn_bwd(
10261027 )
10271028
10281029 # mma
1029- with tlx .async_task (num_warps = 1 , registers = 88 ):
1030+ with tlx .async_task (num_warps = 1 , registers = 48 ):
10301031 _ , _ , start_m , _ , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H , N_CTX , BLOCK_M1 ,
10311032 BLOCK_N1 )
10321033
@@ -1038,10 +1039,66 @@ def _attn_bwd(
10381039 tl .static_assert (BLOCK_N1 % BLOCK_M1 == 0 )
10391040 curr_m = start_m
10401041 step_m = BLOCK_M1
1041- for blk_idx in range (num_steps ):
1042+ blk_idx = 0
1043+
1044+ # -----------------------------------------------------------
1045+ # Prolog
1046+ #
1047+ # 1. qkT = tl.dot(k, qT)
1048+ # 2. dpT = tl.dot(v, tl.trans(do))
1049+ # 3. dv += tl.dot(ppT, do)
1050+ # -----------------------------------------------------------
1051+
1052+ q_buf_id , q_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_Q )
1053+ do_buf_id , do_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DO )
1054+ tmem_buf_id , tmem_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_TMEM )
1055+
1056+ # Compute qkT = tl.dot(k, qT)
1057+ tlx .barrier_wait (q_fulls [q_buf_id ], q_phase )
1058+ tlx .barrier_wait (qk_empties [tmem_buf_id ], tmem_phase ^ 1 )
1059+ qT = tlx .local_trans (q_tiles [q_buf_id ])
1060+ tlx .async_dot (
1061+ k_tiles [0 ],
1062+ qT ,
1063+ qk_tiles [tmem_buf_id ],
1064+ use_acc = False ,
1065+ mBarriers = [qk_fulls [tmem_buf_id ]],
1066+ )
1067+
1068+ # Compute dpT = tl.dot(v, tl.trans(do))
1069+ tlx .barrier_wait (do_fulls [do_buf_id ], do_phase )
1070+ # As dP uses the same tmem as dQ, wait for dQ release.
1071+ tlx .barrier_wait (dq_empties [tmem_buf_id ], tmem_phase ^ 1 )
1072+ doT = tlx .local_trans (do_tiles [do_buf_id ])
1073+ tlx .async_dot (
1074+ v_tiles [0 ],
1075+ doT ,
1076+ dp_tiles [tmem_buf_id ],
1077+ use_acc = False ,
1078+ mBarriers = [dp_fulls [tmem_buf_id ]],
1079+ )
1080+
1081+ # Compute dv += tl.dot(ppT, do)
1082+ tlx .barrier_wait (p_fulls [tmem_buf_id ], tmem_phase )
1083+ tlx .async_dot (
1084+ p_tiles [tmem_buf_id ],
1085+ do_tiles [do_buf_id ],
1086+ dv_tiles [tmem_buf_id ],
1087+ use_acc = False ,
1088+ mBarriers = [do_empties [do_buf_id ]],
1089+ )
1090+
1091+ # -----------------------------------------------------------
1092+ # Main loop
1093+ # 1. qkT = tl.dot(k, qT)
1094+ # 2. dq = tl.dot(tl.trans(dsT), k) from previous iteration
1095+ # 3. dk += tl.dot(dsT, tl.trans(qT)) from previous iteration
1096+ # 4. dpT = tl.dot(v, tl.trans(do))
1097+ # 5. dv += tl.dot(ppT, do)
1098+ # -----------------------------------------------------------
1099+ for blk_idx in range (1 , num_steps ):
10421100 q_buf_id , q_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_Q )
10431101 tmem_buf_id , tmem_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_TMEM )
1044-
10451102 # Compute qkT = tl.dot(k, qT)
10461103 tlx .barrier_wait (q_fulls [q_buf_id ], q_phase )
10471104 tlx .barrier_wait (qk_empties [tmem_buf_id ], tmem_phase ^ 1 )
@@ -1054,11 +1111,38 @@ def _attn_bwd(
10541111 mBarriers = [qk_fulls [tmem_buf_id ]],
10551112 )
10561113
1114+ prev_blk_idx = blk_idx - 1
1115+ q_buf_id_prev , _ = _get_bufidx_phase (prev_blk_idx , NUM_BUFFERS_Q )
1116+ tmem_buf_id_prev , tmem_phase_prev = _get_bufidx_phase (prev_blk_idx , NUM_BUFFERS_TMEM )
1117+ ds_buf_id_prev , ds_phase_prev = _get_bufidx_phase (prev_blk_idx , NUM_BUFFERS_DS )
1118+
1119+ # Compute dq = tl.dot(tl.trans(dsT), k) from previous iteration
1120+ tlx .barrier_wait (ds_fulls [tmem_buf_id_prev ], ds_phase_prev )
1121+ tlx .barrier_wait (dq_empties [tmem_buf_id_prev ], tmem_phase_prev ^ 1 )
1122+ dsT_view = tlx .local_trans (ds_tiles [ds_buf_id_prev ])
1123+ tlx .async_dot (
1124+ dsT_view ,
1125+ k_tiles [0 ],
1126+ dq_tiles [tmem_buf_id_prev ],
1127+ use_acc = False ,
1128+ mBarriers = [dq_fulls [tmem_buf_id_prev ]],
1129+ )
1130+
1131+ # Compute dk += tl.dot(dsT, tl.trans(qT)) from previous iteration
1132+ tlx .async_dot (
1133+ ds_tiles [ds_buf_id_prev ],
1134+ q_tiles [q_buf_id_prev ],
1135+ dk_tiles [tmem_buf_id_prev ],
1136+ use_acc = prev_blk_idx > 0 ,
1137+ mBarriers = [q_empties [q_buf_id_prev ]],
1138+ )
1139+
1140+ do_buf_id , do_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DO )
10571141 # Compute dpT = tl.dot(v, tl.trans(do))
1058- tlx .barrier_wait (do_fulls [q_buf_id ], q_phase )
1142+ tlx .barrier_wait (do_fulls [do_buf_id ], do_phase )
10591143 # As dP uses the same tmem as dQ, wait for dQ release.
10601144 tlx .barrier_wait (dq_empties [tmem_buf_id ], tmem_phase ^ 1 )
1061- doT = tlx .local_trans (do_tiles [q_buf_id ])
1145+ doT = tlx .local_trans (do_tiles [do_buf_id ])
10621146 tlx .async_dot (
10631147 v_tiles [0 ],
10641148 doT ,
@@ -1071,63 +1155,89 @@ def _attn_bwd(
10711155 tlx .barrier_wait (p_fulls [tmem_buf_id ], tmem_phase )
10721156 tlx .async_dot (
10731157 p_tiles [tmem_buf_id ],
1074- do_tiles [q_buf_id ],
1158+ do_tiles [do_buf_id ],
10751159 dv_tiles [tmem_buf_id ],
1076- use_acc = blk_idx > 0 ,
1077- mBarriers = [do_empties [q_buf_id ]],
1160+ use_acc = True ,
1161+ mBarriers = [do_empties [do_buf_id ]],
10781162 )
10791163
1080- # Compute dk += tl.dot(dsT, tl.trans(qT))
1081- tlx .barrier_wait (ds_fulls [tmem_buf_id ], tmem_phase )
1082- tlx .async_dot (
1083- ds_tiles [tmem_buf_id ],
1084- q_tiles [q_buf_id ],
1085- dk_tiles [tmem_buf_id ],
1086- use_acc = blk_idx > 0 ,
1087- mBarriers = [q_empties [tmem_buf_id ]],
1088- )
1164+ tlx .tcgen05_commit (dv_fulls [kv_buf_id ])
10891165
1090- # Compute dq = tl.dot(tl.trans(dsT), k)
1091- tlx .barrier_wait (dq_empties [tmem_buf_id ], tmem_phase ^ 1 )
1092- dsT_view = tlx .local_trans (ds_tiles [tmem_buf_id ])
1093- tlx .async_dot (
1094- dsT_view ,
1095- k_tiles [0 ],
1096- dq_tiles [tmem_buf_id ],
1097- use_acc = False ,
1098- mBarriers = [dq_fulls [tmem_buf_id ]],
1099- )
1166+ # -----------------------------------------------------------
1167+ # Epilog
1168+ # 4. dk += tl.dot(dsT, tl.trans(qT))
1169+ # 5. dq = tl.dot(tl.trans(dsT), k)
1170+ # -----------------------------------------------------------
1171+ q_buf_id , _ = _get_bufidx_phase (num_steps - 1 , NUM_BUFFERS_Q )
1172+ tmem_buf_id , tmem_phase = _get_bufidx_phase (num_steps - 1 , NUM_BUFFERS_TMEM )
1173+ ds_buf_id , ds_phase = _get_bufidx_phase (num_steps - 1 , NUM_BUFFERS_DS )
1174+ # Compute dk += tl.dot(dsT, tl.trans(qT))
1175+ tlx .barrier_wait (ds_fulls [tmem_buf_id ], ds_phase )
1176+ tlx .async_dot (
1177+ ds_tiles [ds_buf_id ],
1178+ q_tiles [q_buf_id ],
1179+ dk_tiles [tmem_buf_id ],
1180+ use_acc = num_steps > 1 ,
1181+ mBarriers = [q_empties [q_buf_id ], dk_fulls [kv_buf_id ]],
1182+ )
11001183
1101- tlx .tcgen05_commit (dv_fulls [kv_buf_id ])
1102- tlx .tcgen05_commit (dk_fulls [kv_buf_id ])
1184+ # Compute dq = tl.dot(tl.trans(dsT), k)
1185+ tlx .barrier_wait (dq_empties [tmem_buf_id ], tmem_phase ^ 1 )
1186+ dsT_view = tlx .local_trans (ds_tiles [ds_buf_id ])
1187+ tlx .async_dot (
1188+ dsT_view ,
1189+ k_tiles [0 ],
1190+ dq_tiles [tmem_buf_id ],
1191+ use_acc = False ,
1192+ mBarriers = [dq_fulls [tmem_buf_id ]],
1193+ )
11031194
11041195 # load
11051196 with tlx .async_task (num_warps = 1 , registers = 88 ):
11061197 _ , off_bh , start_m , start_n , num_steps = bwd_caculate_offsets (stride_z , stride_h , stride_tok , H , N_CTX ,
11071198 BLOCK_M1 , BLOCK_N1 )
1108-
1199+ # Load K
11091200 kv_buf_id , _ = _get_bufidx_phase (0 , NUM_BUFFERS_KV )
11101201 tlx .barrier_expect_bytes (k_fulls [kv_buf_id ], 2 * BLOCK_N1 * HEAD_DIM ) # float16
11111202 tlx .async_descriptor_load (desc_k , k_tiles [kv_buf_id ], [(off_bh + start_n ).to (tl .int32 ), 0 ],
11121203 k_fulls [kv_buf_id ])
11131204
1205+ # Load Q
1206+ curr_m = start_m
1207+ step_m = BLOCK_M1
1208+ blk_idx = 0
1209+ q_buf_id , q_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_Q )
1210+ tlx .barrier_wait (q_empties [q_buf_id ], q_phase ^ 1 )
1211+ tlx .barrier_expect_bytes (q_fulls [q_buf_id ], 2 * BLOCK_M1 * HEAD_DIM )
1212+ tlx .async_descriptor_load (desc_q , q_tiles [q_buf_id ], [(off_bh + curr_m ).to (tl .int32 ), 0 ], q_fulls [q_buf_id ])
1213+
1214+ # Load V
11141215 tlx .barrier_expect_bytes (v_fulls [kv_buf_id ], 2 * BLOCK_N1 * HEAD_DIM ) # float16
11151216 tlx .async_descriptor_load (desc_v , v_tiles [kv_buf_id ], [(off_bh + start_n ).to (tl .int32 ), 0 ],
11161217 v_fulls [kv_buf_id ])
11171218
1118- curr_m = start_m
1119- step_m = BLOCK_M1
1120- for blk_idx in range (num_steps ):
1219+ # Load dO
1220+ do_buf_id , do_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DO )
1221+ tlx .barrier_wait (do_empties [do_buf_id ], do_phase ^ 1 )
1222+ tlx .barrier_expect_bytes (do_fulls [do_buf_id ], 2 * BLOCK_M1 * HEAD_DIM )
1223+ tlx .async_descriptor_load (desc_do , do_tiles [do_buf_id ], [(off_bh + curr_m ).to (tl .int32 ), 0 ],
1224+ do_fulls [do_buf_id ])
1225+ curr_m += step_m
1226+
1227+ for blk_idx in range (1 , num_steps ):
11211228 q_buf_id , q_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_Q )
1229+ do_buf_id , do_phase = _get_bufidx_phase (blk_idx , NUM_BUFFERS_DO )
1230+ # Load Q
11221231 tlx .barrier_wait (q_empties [q_buf_id ], q_phase ^ 1 )
11231232 tlx .barrier_expect_bytes (q_fulls [q_buf_id ], 2 * BLOCK_M1 * HEAD_DIM )
11241233 tlx .async_descriptor_load (desc_q , q_tiles [q_buf_id ], [(off_bh + curr_m ).to (tl .int32 ), 0 ],
11251234 q_fulls [q_buf_id ])
11261235
1127- tlx .barrier_wait (do_empties [q_buf_id ], q_phase ^ 1 )
1128- tlx .barrier_expect_bytes (do_fulls [q_buf_id ], 2 * BLOCK_M1 * HEAD_DIM )
1129- tlx .async_descriptor_load (desc_do , do_tiles [q_buf_id ], [(off_bh + curr_m ).to (tl .int32 ), 0 ],
1130- do_fulls [q_buf_id ])
1236+ # Load dO
1237+ tlx .barrier_wait (do_empties [do_buf_id ], do_phase ^ 1 )
1238+ tlx .barrier_expect_bytes (do_fulls [do_buf_id ], 2 * BLOCK_M1 * HEAD_DIM )
1239+ tlx .async_descriptor_load (desc_do , do_tiles [do_buf_id ], [(off_bh + curr_m ).to (tl .int32 ), 0 ],
1240+ do_fulls [do_buf_id ])
11311241 curr_m += step_m
11321242
11331243
@@ -1241,7 +1351,7 @@ def grid(meta):
12411351 1 , # (or cdiv over M if you need)
12421352 BATCH * N_HEAD ) # batch*heads
12431353
1244- _attn_bwd [grid ](
1354+ _attn_bwd_ws [grid ](
12451355 desc_q , desc_k , desc_v , ctx .sm_scale , desc_do , desc_dq , desc_dk , desc_dv , #
12461356 M , delta , #
12471357 q .stride (0 ), q .stride (1 ), q .stride (2 ), q .stride (3 ), #
0 commit comments