@@ -82,7 +82,8 @@ def forward_kernel(
8282 stride_kvbl_b ,
8383 stride_kvbl_h ,
8484 stride_kvbl_m ,
85- nheads ,
85+ stride_lse_b ,
86+ kv_heads ,
8687 seqlen_q ,
8788 seqlen_k ,
8889 seqlen_q_rounded ,
@@ -100,9 +101,9 @@ def forward_kernel(
100101):
101102 start_m = tl .program_id (0 )
102103 off_hb = tl .program_id (1 )
103- off_b = off_hb // nheads
104104
105- off_h = off_hb % nheads
105+ off_b = off_hb // kv_heads
106+ off_h = off_hb % kv_heads
106107
107108 offs_qh = off_h * QUERY_HEAD_GROUPS + tl .arange (0 , QUERY_HEAD_GROUPS )
108109
@@ -142,6 +143,7 @@ def forward_kernel(
142143
143144 lse_ptrs = (
144145 Lse +
146+ off_b * stride_lse_b +
145147 offs_qh [:, None ] * seqlen_q_rounded +
146148 offs_m [None , :]
147149 )
@@ -348,7 +350,7 @@ def forward_kernel(
348350 # write back lse
349351
350352 lse_i = lse_i .reshape ([QUERY_HEAD_GROUPS , BLOCK ])
351- tl .store (lse_ptrs , lse_i )
353+ tl .store (lse_ptrs , lse_i , mask = offs_m [ None , :] < seqlen_q )
352354
353355 # write to output
354356
@@ -429,7 +431,8 @@ def native_sparse_attn_forward(
429431 kv_block_indices .stride (0 ),
430432 kv_block_indices .stride (1 ),
431433 kv_block_indices .stride (2 ),
432- nheads ,
434+ lse .stride (0 ),
435+ kv_heads ,
433436 seqlen_q ,
434437 seqlen_k ,
435438 seqlen_q_rounded ,
@@ -591,7 +594,6 @@ def backward_kernel_one_col_block(
591594 k_ptrs = K + (offs_n [:, None ] * stride_kn + offs_d [None , :])
592595 v_ptrs = V + (offs_n [:, None ] * stride_vn + offs_d [None , :])
593596
594-
595597 q_ptrs = (
596598 Q +
597599 offs_g [:, None , None ] * stride_qh +
@@ -956,6 +958,8 @@ def backward_kernel(
956958 stride_kvbl_b ,
957959 stride_kvbl_h ,
958960 stride_kvbl_m ,
961+ stride_lse_b ,
962+ stride_D_b ,
959963 kv_heads ,
960964 seqlen_q ,
961965 seqlen_k ,
@@ -993,8 +997,16 @@ def backward_kernel(
993997 kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
994998
995999 # pointer to row-wise quantities in value-like data
996- D += off_hb * QUERY_HEAD_GROUPS * seqlen_q_rounded
997- LSE += off_hb * QUERY_HEAD_GROUPS * seqlen_q_rounded
1000+
1001+ D += (
1002+ off_b * stride_D_b +
1003+ off_h * QUERY_HEAD_GROUPS * seqlen_q_rounded
1004+ )
1005+
1006+ LSE += (
1007+ off_b * stride_lse_b +
1008+ off_h * QUERY_HEAD_GROUPS * seqlen_q_rounded
1009+ )
9981010
9991011 num_block_n = tl .cdiv (seqlen_k , BLOCK )
10001012 for start_n in range (0 , num_block_n ):
@@ -1137,6 +1149,8 @@ def native_sparse_attn_backward(
11371149 kv_block_indices .stride (0 ),
11381150 kv_block_indices .stride (1 ),
11391151 kv_block_indices .stride (2 ),
1152+ lse .stride (0 ),
1153+ delta .stride (0 ),
11401154 kv_heads ,
11411155 seqlen_q ,
11421156 seqlen_k ,
0 commit comments