@@ -106,7 +106,8 @@ def forward_kernel_causal_and_sparse(
106106 QUERY_HEAD_GROUPS : tl .constexpr ,
107107 QUERY_EXPAND_DIM : tl .constexpr ,
108108 NUM_SEL_KV_BLOCKS : tl .constexpr ,
109- INCLUDE_BLOCK_CAUSAL : tl .constexpr
109+ INCLUDE_BLOCK_CAUSAL : tl .constexpr ,
110+ SLIDING : tl .constexpr
110111):
111112 start_m = tl .program_id (0 )
112113 off_hb = tl .program_id (1 )
@@ -117,7 +118,6 @@ def forward_kernel_causal_and_sparse(
117118 offs_qh = off_h * QUERY_HEAD_GROUPS + tl .arange (0 , QUERY_HEAD_GROUPS )
118119
119120 offs_m = start_m * BLOCK + tl .arange (0 , BLOCK )
120- offs_n = start_m * BLOCK + tl .arange (0 , BLOCK )
121121 offs_d = tl .arange (0 , BLOCK_HEADDIM )
122122
123123 q_ptrs = (
@@ -181,6 +181,9 @@ def forward_kernel_causal_and_sparse(
181181 )
182182
183183 if INCLUDE_BLOCK_CAUSAL :
184+
185+ offs_n = start_m * BLOCK + tl .arange (0 , BLOCK )
186+
184187 k_ptrs = (
185188 K +
186189 off_b * stride_kb +
@@ -225,11 +228,21 @@ def forward_kernel_causal_and_sparse(
225228 qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
226229
227230 if not EVEN_N :
228- qk += tl .where (offs_n [None , :] < seqlen_k , 0 , float ("-inf" ))
231+ within_range_mask = offs_n [None , :] < seqlen_k
232+
233+ if SLIDING :
234+ within_range_mask &= offs_n [None , :] >= 0.
235+
236+ qk += tl .where (within_range_mask , 0 , float ("-inf" ))
229237
230238 qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
231239
232- qk += tl .where (offs_m [:, None , None ] >= offs_n [None , None , :], 0 , float ("-inf" ))
240+ causal_mask = offs_m [:, None , None ] >= offs_n [None , None , :]
241+
242+ if SLIDING :
243+ causal_mask &= (offs_n [None , None , :] - offs_m [:, None , None ]) <= BLOCK
244+
245+ qk += tl .where (causal_mask , 0 , float ("-inf" ))
233246
234247 m_ij = tl .maximum (tl .max (qk , 2 ) * softmax_scale , lse_i )
235248 p = tl .exp (qk * softmax_scale - m_ij [:, :, None ])
@@ -456,7 +469,8 @@ def forward_kernel(
456469 QUERY_HEAD_GROUPS : tl .constexpr ,
457470 QUERY_EXPAND_DIM : tl .constexpr ,
458471 NUM_SEL_KV_BLOCKS : tl .constexpr ,
459- INCLUDE_BLOCK_CAUSAL : tl .constexpr
472+ INCLUDE_BLOCK_CAUSAL : tl .constexpr ,
473+ SLIDING : tl .constexpr
460474):
461475 forward_kernel_causal_and_sparse (
462476 Q ,
@@ -498,7 +512,8 @@ def forward_kernel(
498512 QUERY_HEAD_GROUPS ,
499513 QUERY_EXPAND_DIM ,
500514 NUM_SEL_KV_BLOCKS ,
501- INCLUDE_BLOCK_CAUSAL
515+ INCLUDE_BLOCK_CAUSAL ,
516+ SLIDING
502517 )
503518
504519def native_sparse_attn_forward (
@@ -578,6 +593,7 @@ def native_sparse_attn_forward(
578593 QUERY_EXPAND_DIM = 16 // head_groups ,
579594 NUM_SEL_KV_BLOCKS = num_selected_fine_blocks ,
580595 INCLUDE_BLOCK_CAUSAL = include_block_causal ,
596+ SLIDING = False ,
581597 num_warps = num_warps ,
582598 num_stages = 1 ,
583599 )
@@ -626,8 +642,8 @@ def backward_preprocess_do_o_dot(
626642 + off_h * stride_doh
627643 + offs_m [:, None ] * stride_dom
628644 + offs_d [None , :],
629- mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
630- other = 0.0 ,
645+ mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
646+ other = 0.0 ,
631647 ).to (tl .float32 )
632648
633649 delta = tl .sum (o * do , axis = 1 )
0 commit comments