@@ -180,112 +180,121 @@ def forward_kernel_causal_and_sparse(
180180 other = 0.0
181181 )
182182
183+ q = q .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
184+
183185 if INCLUDE_BLOCK_CAUSAL :
184186
185- offs_n = start_m * BLOCK + tl .arange (0 , BLOCK )
187+ if SLIDING :
188+ num_kv_blocks = 2
189+ offset = - BLOCK
190+ else :
191+ num_kv_blocks = 1
192+ offset = 0
186193
187- k_ptrs = (
188- K +
189- off_b * stride_kb +
190- off_h * stride_kh +
191- offs_n [:, None ] * stride_kn +
192- offs_d [None , :]
193- )
194+ offs_n = start_m * BLOCK + tl .arange (0 , BLOCK ) + offset
194195
195- v_ptrs = (
196- V +
197- off_b * stride_vb +
198- off_h * stride_vh +
199- offs_n [:, None ] * stride_vn +
200- offs_d [None , :]
201- )
196+ for _ in range (num_kv_blocks ):
202197
203- if EVEN_N & EVEN_M :
204- if EVEN_HEADDIM :
205- k = tl .load (k_ptrs )
206- else :
207- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
208- else :
209- if EVEN_HEADDIM :
210- k = tl .load (
211- k_ptrs ,
212- mask = offs_n [:, None ] < seqlen_k ,
213- other = 0.0 ,
214- )
215- else :
216- k = tl .load (
217- k_ptrs ,
218- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
219- other = 0.0 ,
220- )
198+ k_ptrs = (
199+ K +
200+ off_b * stride_kb +
201+ off_h * stride_kh +
202+ offs_n [:, None ] * stride_kn +
203+ offs_d [None , :]
204+ )
221205
222- qk = tl .zeros ([BLOCK * QUERY_HEAD_GROUPS , BLOCK ], dtype = tl .float32 )
206+ v_ptrs = (
207+ V +
208+ off_b * stride_vb +
209+ off_h * stride_vh +
210+ offs_n [:, None ] * stride_vn +
211+ offs_d [None , :]
212+ )
223213
224- q = q .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
214+ if EVEN_N & EVEN_M :
215+ if EVEN_HEADDIM :
216+ k = tl .load (k_ptrs )
217+ else :
218+ k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
219+ else :
220+ if EVEN_HEADDIM :
221+ k = tl .load (
222+ k_ptrs ,
223+ mask = offs_n [:, None ] < seqlen_k ,
224+ other = 0.0 ,
225+ )
226+ else :
227+ k = tl .load (
228+ k_ptrs ,
229+ mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
230+ other = 0.0 ,
231+ )
225232
226- qk + = tl .dot ( q , tl .trans ( k ) )
233+ qk = tl .zeros ([ BLOCK * QUERY_HEAD_GROUPS , BLOCK ], dtype = tl .float32 )
227234
228- qk = qk . reshape ( BLOCK , QUERY_HEAD_GROUPS , BLOCK )
235+ qk += tl . dot ( q , tl . trans ( k ) )
229236
230- if not EVEN_N :
231- within_range_mask = offs_n [None , :] < seqlen_k
237+ qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
232238
233- if SLIDING :
234- within_range_mask & = offs_n [None , :] >= 0.
239+ if not EVEN_N :
240+ within_range_mask = offs_n [None , :] < seqlen_k
235241
236- qk += tl .where (within_range_mask , 0 , float ("-inf" ))
242+ if SLIDING :
243+ within_range_mask &= offs_n [None , :] >= 0.
237244
238- qk = qk . reshape ( BLOCK , QUERY_HEAD_GROUPS , BLOCK )
245+ qk += tl . where ( within_range_mask , 0 , float ( "-inf" ) )
239246
240- causal_mask = offs_m [:, None , None ] >= offs_n [ None , None , :]
247+ qk = qk . reshape ( BLOCK , QUERY_HEAD_GROUPS , BLOCK )
241248
242- if SLIDING :
243- causal_mask &= (offs_n [None , None , :] - offs_m [:, None , None ]) <= BLOCK
249+ causal_mask = offs_m [:, None , None ] >= offs_n [None , None , :]
244250
245- qk += tl .where (causal_mask , 0 , float ("-inf" ))
251+ if SLIDING :
252+ causal_mask &= (offs_n [None , None , :] - offs_m [:, None , None ]) <= BLOCK
246253
247- m_ij = tl .maximum (tl .max (qk , 2 ) * softmax_scale , lse_i )
248- p = tl .exp (qk * softmax_scale - m_ij [:, :, None ])
254+ qk += tl .where (causal_mask , 0 , float ("-inf" ))
249255
250- l_ij = tl .sum (p , 2 )
256+ m_ij = tl .maximum (tl .max (qk , 2 ) * softmax_scale , lse_i )
257+ p = tl .exp (qk * softmax_scale - m_ij [:, :, None ])
251258
252- acc_o_scale = tl .exp (m_i - m_ij )
253- acc_o *= acc_o_scale [:, :, None ]
259+ l_ij = tl .sum (p , 2 )
254260
255- if EVEN_N & EVEN_M :
256- if EVEN_HEADDIM :
257- v = tl .load (v_ptrs )
258- else :
259- v = tl .load (
260- v_ptrs ,
261- mask = offs_d [None , :] < headdim ,
262- other = 0.0
263- )
264- else :
265- if EVEN_HEADDIM :
266- v = tl .load (
267- v_ptrs ,
268- mask = offs_n [:, None ] < seqlen_k ,
269- other = 0.0 ,
270- )
261+ acc_o_scale = tl .exp (m_i - m_ij )
262+ acc_o *= acc_o_scale [:, :, None ]
263+
264+ if EVEN_N & EVEN_M :
265+ if EVEN_HEADDIM :
266+ v = tl .load (v_ptrs )
267+ else :
268+ v = tl .load (
269+ v_ptrs ,
270+ mask = offs_d [None , :] < headdim ,
271+ other = 0.0
272+ )
271273 else :
272- v = tl .load (
273- v_ptrs ,
274- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
275- other = 0.0 ,
276- )
274+ if EVEN_HEADDIM :
275+ v = tl .load (
276+ v_ptrs ,
277+ mask = offs_n [:, None ] < seqlen_k ,
278+ other = 0.0 ,
279+ )
280+ else :
281+ v = tl .load (
282+ v_ptrs ,
283+ mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
284+ other = 0.0 ,
285+ )
277286
278- p = p .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK ).to (v .dtype )
287+ p = p .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK ).to (v .dtype )
279288
280- causal_o = tl .dot (p , v )
289+ causal_o = tl .dot (p , v )
281290
282- acc_o += causal_o .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
291+ acc_o += causal_o .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
283292
284- # -- update statistics
293+ # -- update statistics
285294
286- m_i = m_ij
287- l_i_new = tl .exp (lse_i - m_ij ) + l_ij
288- lse_i = m_ij + tl .log (l_i_new )
295+ m_i = m_ij
296+ l_i_new = tl .exp (lse_i - m_ij ) + l_ij
297+ lse_i = m_ij + tl .log (l_i_new )
289298
290299 # # take care of the selected kv blocks
291300
@@ -1029,6 +1038,7 @@ def backward_kernel_one_col_block_causal(
10291038 BLOCK : tl .constexpr ,
10301039 QUERY_HEAD_GROUPS : tl .constexpr ,
10311040 QUERY_EXPAND_DIM : tl .constexpr ,
1041+ SLIDING : tl .constexpr
10321042):
10331043 # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
10341044
@@ -1143,11 +1153,16 @@ def backward_kernel_one_col_block_causal(
11431153
11441154 qk = qk .reshape (QUERY_HEAD_GROUPS , BLOCK , BLOCK )
11451155
1156+ mask = offs_m [:, None ] >= offs_n [None , :]
1157+
11461158 # Trying to combine the two masks seem to make the result wrong
11471159 if not EVEN_N : # Need to mask out otherwise the softmax is wrong
1148- qk = tl .where (offs_n [None , :] < seqlen_k , qk , float ("-inf" ))
1160+ mask &= offs_n [None , :] < seqlen_k
1161+
1162+ if SLIDING :
1163+ mask &= (offs_n [None , :] - offs_m [:, None ]) < BLOCK
11491164
1150- qk = tl .where (offs_m [:, None ] >= ( offs_n [ None , :]) , qk , float ("-inf" ))
1165+ qk = tl .where (mask , qk , float ("-inf" ))
11511166
11521167 qk = qk .reshape (QUERY_HEAD_GROUPS * BLOCK , BLOCK )
11531168
@@ -1315,7 +1330,8 @@ def backward_kernel(
13151330 QUERY_HEAD_GROUPS : tl .constexpr ,
13161331 QUERY_EXPAND_DIM : tl .constexpr ,
13171332 RETURN_SEL_GRADS : tl .constexpr ,
1318- INCLUDE_BLOCK_CAUSAL : tl .constexpr
1333+ INCLUDE_BLOCK_CAUSAL : tl .constexpr ,
1334+ SLIDING : tl .constexpr ,
13191335):
13201336 off_hb = tl .program_id (1 )
13211337 off_b = off_hb // kv_heads
@@ -1393,6 +1409,7 @@ def backward_kernel(
13931409 BLOCK = BLOCK ,
13941410 QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS ,
13951411 QUERY_EXPAND_DIM = QUERY_EXPAND_DIM ,
1412+ SLIDING = SLIDING
13961413 )
13971414 else :
13981415 for start_n in range (0 , num_block_n ):
@@ -1448,7 +1465,8 @@ def native_sparse_attn_backward(
14481465 dq , dk , dv ,
14491466 block_size = 128 ,
14501467 include_block_causal = True ,
1451- return_sel_grads = False
1468+ return_sel_grads = False ,
1469+ sliding = False
14521470):
14531471 device = do .device
14541472
@@ -1563,6 +1581,7 @@ def native_sparse_attn_backward(
15631581 EVEN_HEADDIM = BLOCK_HEADDIM == dim ,
15641582 RETURN_SEL_GRADS = return_sel_grads ,
15651583 INCLUDE_BLOCK_CAUSAL = include_block_causal ,
1584+ SLIDING = sliding
15661585 # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
15671586 # num_warps=num_warps,
15681587 # num_stages=1,
@@ -1600,7 +1619,7 @@ def forward(
16001619 selected_block_indices ,
16011620 fmask ,
16021621 block_size = block_size ,
1603- include_block_causal = include_block_causal
1622+ include_block_causal = include_block_causal ,
16041623 )
16051624
16061625 ctx .save_for_backward (fq , fk , fv , selected_block_indices , fmask , out , lse )
0 commit comments