@@ -1208,13 +1208,15 @@ def backward_kernel(
12081208 BLOCK : tl .constexpr ,
12091209 QUERY_HEAD_GROUPS : tl .constexpr ,
12101210 QUERY_EXPAND_DIM : tl .constexpr ,
1211- NUM_SEL_KV_BLOCKS : tl .constexpr
12121211):
12131212 off_hb = tl .program_id (1 )
12141213 off_b = off_hb // kv_heads
12151214 off_h = off_hb % kv_heads
12161215 off_qh = off_h * QUERY_HEAD_GROUPS
12171216
1217+ IS_CAUSAL = tl .program_id (0 ) == 0
1218+ OFF_SEL_KV_BLOCKS = tl .program_id (0 ) - 1
1219+
12181220 # offset pointers for batch/head
12191221
12201222 Q += off_b * stride_qb + off_qh * stride_qh
@@ -1244,46 +1246,47 @@ def backward_kernel(
12441246
12451247 num_block_n = tl .cdiv (seqlen_k , BLOCK )
12461248
1247- for start_n in range (0 , num_block_n ):
1248- backward_kernel_one_col_block_causal (
1249- start_n ,
1250- Q ,
1251- K ,
1252- V ,
1253- kv_block_indices ,
1254- kv_block_mask ,
1255- DO ,
1256- DQ ,
1257- DK ,
1258- DV ,
1259- LSE ,
1260- D ,
1261- softmax_scale ,
1262- stride_qm ,
1263- stride_kn ,
1264- stride_vn ,
1265- stride_dom ,
1266- stride_dqm ,
1267- stride_dkn ,
1268- stride_dvn ,
1269- stride_kvbl_m ,
1270- stride_qh ,
1271- stride_doh ,
1272- stride_dqh ,
1273- seqlen_q ,
1274- seqlen_k ,
1275- seqlen_q_rounded ,
1276- headdim ,
1277- BLOCK_HEADDIM = BLOCK_HEADDIM ,
1278- EVEN_M = EVEN_M ,
1279- EVEN_N = EVEN_N ,
1280- EVEN_HEADDIM = EVEN_HEADDIM ,
1281- BLOCK = BLOCK ,
1282- QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1283- QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1284- )
1285-
1286- for off_sel_kv_blocks in range (NUM_SEL_KV_BLOCKS ):
1249+ if IS_CAUSAL :
1250+ for start_n in range (0 , num_block_n ):
1251+ backward_kernel_one_col_block_causal (
1252+ start_n ,
1253+ Q ,
1254+ K ,
1255+ V ,
1256+ kv_block_indices ,
1257+ kv_block_mask ,
1258+ DO ,
1259+ DQ ,
1260+ DK ,
1261+ DV ,
1262+ LSE ,
1263+ D ,
1264+ softmax_scale ,
1265+ stride_qm ,
1266+ stride_kn ,
1267+ stride_vn ,
1268+ stride_dom ,
1269+ stride_dqm ,
1270+ stride_dkn ,
1271+ stride_dvn ,
1272+ stride_kvbl_m ,
1273+ stride_qh ,
1274+ stride_doh ,
1275+ stride_dqh ,
1276+ seqlen_q ,
1277+ seqlen_k ,
1278+ seqlen_q_rounded ,
1279+ headdim ,
1280+ BLOCK_HEADDIM = BLOCK_HEADDIM ,
1281+ EVEN_M = EVEN_M ,
1282+ EVEN_N = EVEN_N ,
1283+ EVEN_HEADDIM = EVEN_HEADDIM ,
1284+ BLOCK = BLOCK ,
1285+ QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1286+ QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1287+ )
1288+ else :
1289+ for start_n in range (0 , num_block_n ):
12871290 backward_kernel_one_col_block_sparse (
12881291 start_n ,
12891292 Q ,
@@ -1320,7 +1323,7 @@ def backward_kernel(
13201323 BLOCK = BLOCK ,
13211324 QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
13221325 QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1323- OFF_SEL_KV_BLOCKS = off_sel_kv_blocks
1326+ OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS
13241327 )
13251328
13261329def native_sparse_attn_backward (
@@ -1383,7 +1386,7 @@ def native_sparse_attn_backward(
13831386 BLOCK_HEADDIM = BLOCK_HEADDIM ,
13841387 )
13851388
1386- grid = lambda META : (1 , batch * kv_heads )
1389+ grid = lambda META : (num_sel_fine_blocks + 1 , batch * kv_heads )
13871390
13881391 backward_kernel [grid ](
13891392 q ,
@@ -1437,7 +1440,6 @@ def native_sparse_attn_backward(
14371440 BLOCK = block_size ,
14381441 QUERY_HEAD_GROUPS = head_groups ,
14391442 QUERY_EXPAND_DIM = 16 // head_groups ,
1440- NUM_SEL_KV_BLOCKS = num_sel_fine_blocks ,
14411443 EVEN_M = divisible_by (seqlen_q , block_size ),
14421444 EVEN_N = divisible_by (seqlen_k , block_size ),
14431445 EVEN_HEADDIM = BLOCK_HEADDIM == dim
0 commit comments