@@ -1526,93 +1526,91 @@ def backward_kernel(
15261526 off_qh * seqlen_q_rounded
15271527 )
15281528
1529- num_block_n = tl .cdiv ( seqlen_k , BLOCK )
1529+ start_n = tl .program_id ( 2 )
15301530
15311531 if IS_CAUSAL :
1532- for start_n in range (0 , num_block_n ):
1533- backward_kernel_one_col_block_causal (
1534- start_n ,
1535- Q ,
1536- K ,
1537- V ,
1538- kv_block_indices ,
1539- kv_block_mask ,
1540- DO ,
1541- DQ ,
1542- DK ,
1543- DV ,
1544- LSE ,
1545- D ,
1546- softmax_scale ,
1547- stride_qm ,
1548- stride_kn ,
1549- stride_vn ,
1550- stride_dom ,
1551- stride_dqm ,
1552- stride_dkn ,
1553- stride_dvn ,
1554- stride_kvbl_m ,
1555- stride_qh ,
1556- stride_doh ,
1557- stride_dqh ,
1558- seqlen_q ,
1559- seqlen_k ,
1560- seqlen_q_rounded ,
1561- headdim ,
1562- BLOCK_HEADDIM = BLOCK_HEADDIM ,
1563- EVEN_M = EVEN_M ,
1564- EVEN_N = EVEN_N ,
1565- EVEN_HEADDIM = EVEN_HEADDIM ,
1566- BLOCK = BLOCK ,
1567- SEL_BLOCK = SEL_BLOCK ,
1568- QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1569- QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1570- SLIDING = SLIDING
1571- )
1532+ backward_kernel_one_col_block_causal (
1533+ start_n ,
1534+ Q ,
1535+ K ,
1536+ V ,
1537+ kv_block_indices ,
1538+ kv_block_mask ,
1539+ DO ,
1540+ DQ ,
1541+ DK ,
1542+ DV ,
1543+ LSE ,
1544+ D ,
1545+ softmax_scale ,
1546+ stride_qm ,
1547+ stride_kn ,
1548+ stride_vn ,
1549+ stride_dom ,
1550+ stride_dqm ,
1551+ stride_dkn ,
1552+ stride_dvn ,
1553+ stride_kvbl_m ,
1554+ stride_qh ,
1555+ stride_doh ,
1556+ stride_dqh ,
1557+ seqlen_q ,
1558+ seqlen_k ,
1559+ seqlen_q_rounded ,
1560+ headdim ,
1561+ BLOCK_HEADDIM = BLOCK_HEADDIM ,
1562+ EVEN_M = EVEN_M ,
1563+ EVEN_N = EVEN_N ,
1564+ EVEN_HEADDIM = EVEN_HEADDIM ,
1565+ BLOCK = BLOCK ,
1566+ SEL_BLOCK = SEL_BLOCK ,
1567+ QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1568+ QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1569+ SLIDING = SLIDING
1570+ )
15721571 else :
1573- for start_n in range (0 , num_block_n ):
1574- backward_kernel_one_col_block_sparse (
1575- start_n ,
1576- Q ,
1577- K ,
1578- V ,
1579- kv_block_indices ,
1580- kv_block_mask ,
1581- kv_block_grads ,
1582- DO ,
1583- DQ ,
1584- DK ,
1585- DV ,
1586- LSE ,
1587- D ,
1588- softmax_scale ,
1589- stride_qm ,
1590- stride_kn ,
1591- stride_vn ,
1592- stride_dom ,
1593- stride_dqm ,
1594- stride_dkn ,
1595- stride_dvn ,
1596- stride_kvbl_m ,
1597- stride_qh ,
1598- stride_doh ,
1599- stride_dqh ,
1600- seqlen_q ,
1601- seqlen_k ,
1602- seqlen_q_rounded ,
1603- headdim ,
1604- BLOCK_HEADDIM = BLOCK_HEADDIM ,
1605- EVEN_M = EVEN_M ,
1606- EVEN_N = EVEN_N ,
1607- EVEN_HEADDIM = EVEN_HEADDIM ,
1608- BLOCK = BLOCK ,
1609- QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1610- QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1611- RETURN_SEL_GRADS = RETURN_SEL_GRADS ,
1612- OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS ,
1613- BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT ,
1614- BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT ,
1615- )
1572+ backward_kernel_one_col_block_sparse (
1573+ start_n ,
1574+ Q ,
1575+ K ,
1576+ V ,
1577+ kv_block_indices ,
1578+ kv_block_mask ,
1579+ kv_block_grads ,
1580+ DO ,
1581+ DQ ,
1582+ DK ,
1583+ DV ,
1584+ LSE ,
1585+ D ,
1586+ softmax_scale ,
1587+ stride_qm ,
1588+ stride_kn ,
1589+ stride_vn ,
1590+ stride_dom ,
1591+ stride_dqm ,
1592+ stride_dkn ,
1593+ stride_dvn ,
1594+ stride_kvbl_m ,
1595+ stride_qh ,
1596+ stride_doh ,
1597+ stride_dqh ,
1598+ seqlen_q ,
1599+ seqlen_k ,
1600+ seqlen_q_rounded ,
1601+ headdim ,
1602+ BLOCK_HEADDIM = BLOCK_HEADDIM ,
1603+ EVEN_M = EVEN_M ,
1604+ EVEN_N = EVEN_N ,
1605+ EVEN_HEADDIM = EVEN_HEADDIM ,
1606+ BLOCK = BLOCK ,
1607+ QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
1608+ QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1609+ RETURN_SEL_GRADS = RETURN_SEL_GRADS ,
1610+ OFF_SEL_KV_BLOCKS = OFF_SEL_KV_BLOCKS ,
1611+ BLOCK_DV_USE_DOT = BLOCK_DV_USE_DOT ,
1612+ BLOCK_DK_USE_DOT = BLOCK_DK_USE_DOT ,
1613+ )
16161614
16171615def native_sparse_attn_backward (
16181616 do ,
@@ -1692,7 +1690,8 @@ def native_sparse_attn_backward(
16921690
16931691 grid = lambda META : (
16941692 num_sel_fine_blocks + int (include_block_causal ),
1695- batch * kv_heads
1693+ batch * kv_heads ,
1694+ triton .cdiv (seqlen_k , META ['BLOCK' ])
16961695 )
16971696
16981697 backward_kernel [grid ](
0 commit comments