@@ -112,7 +112,8 @@ def forward_kernel(
112112 BLOCK : tl .constexpr ,
113113 QUERY_HEAD_GROUPS : tl .constexpr ,
114114 QUERY_EXPAND_DIM : tl .constexpr ,
115- NUM_SEL_KV_BLOCKS : tl .constexpr
115+ NUM_SEL_KV_BLOCKS : tl .constexpr ,
116+ INCLUDE_BLOCK_CAUSAL : tl .constexpr
116117):
117118 start_m = tl .program_id (0 )
118119 off_hb = tl .program_id (1 )
@@ -134,22 +135,6 @@ def forward_kernel(
134135 offs_d [None , None , :]
135136 )
136137
137- k_ptrs = (
138- K +
139- off_b * stride_kb +
140- off_h * stride_kh +
141- offs_n [:, None ] * stride_kn +
142- offs_d [None , :]
143- )
144-
145- v_ptrs = (
146- V +
147- off_b * stride_vb +
148- off_h * stride_vh +
149- offs_n [:, None ] * stride_vn +
150- offs_d [None , :]
151- )
152-
153138 # maximum
154139
155140 m_i = tl .zeros ([BLOCK , QUERY_HEAD_GROUPS ], dtype = tl .float32 ) - float ("inf" )
@@ -202,82 +187,99 @@ def forward_kernel(
202187 other = 0.0
203188 )
204189
205- if EVEN_N & EVEN_M :
206- if EVEN_HEADDIM :
207- k = tl .load (k_ptrs )
208- else :
209- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
210- else :
211- if EVEN_HEADDIM :
212- k = tl .load (
213- k_ptrs ,
214- mask = offs_n [:, None ] < seqlen_k ,
215- other = 0.0 ,
216- )
217- else :
218- k = tl .load (
219- k_ptrs ,
220- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
221- other = 0.0 ,
222- )
190+ if INCLUDE_BLOCK_CAUSAL :
191+ k_ptrs = (
192+ K +
193+ off_b * stride_kb +
194+ off_h * stride_kh +
195+ offs_n [:, None ] * stride_kn +
196+ offs_d [None , :]
197+ )
223198
224- qk = tl .zeros ([BLOCK * QUERY_HEAD_GROUPS , BLOCK ], dtype = tl .float32 )
199+ v_ptrs = (
200+ V +
201+ off_b * stride_vb +
202+ off_h * stride_vh +
203+ offs_n [:, None ] * stride_vn +
204+ offs_d [None , :]
205+ )
225206
226- q = q .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
207+ if EVEN_N & EVEN_M :
208+ if EVEN_HEADDIM :
209+ k = tl .load (k_ptrs )
210+ else :
211+ k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
212+ else :
213+ if EVEN_HEADDIM :
214+ k = tl .load (
215+ k_ptrs ,
216+ mask = offs_n [:, None ] < seqlen_k ,
217+ other = 0.0 ,
218+ )
219+ else :
220+ k = tl .load (
221+ k_ptrs ,
222+ mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
223+ other = 0.0 ,
224+ )
227225
228- qk + = tl .dot ( q , tl .trans ( k ) )
226+ qk = tl .zeros ([ BLOCK * QUERY_HEAD_GROUPS , BLOCK ], dtype = tl .float32 )
229227
230- qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
228+ q = q .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
231229
232- if not EVEN_N :
233- qk += tl .where (offs_n [None , :] < seqlen_k , 0 , float ("-inf" ))
230+ qk += tl .dot (q , tl .trans (k ))
234231
235- qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
232+ qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
236233
237- qk += tl .where (offs_m [:, None , None ] >= offs_n [None , None , :], 0 , float ("-inf" ))
234+ if not EVEN_N :
235+ qk += tl .where (offs_n [None , :] < seqlen_k , 0 , float ("-inf" ))
238236
239- m_ij = tl .maximum (tl .max (qk , 2 ) * softmax_scale , lse_i )
240- p = tl .exp (qk * softmax_scale - m_ij [:, :, None ])
237+ qk = qk .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK )
241238
242- l_ij = tl .sum ( p , 2 )
239+ qk + = tl .where ( offs_m [:, None , None ] >= offs_n [ None , None , :], 0 , float ( "-inf" ) )
243240
244- acc_o_scale = tl .exp ( m_i - m_ij )
245- acc_o *= acc_o_scale [:, :, None ]
241+ m_ij = tl .maximum ( tl . max ( qk , 2 ) * softmax_scale , lse_i )
242+ p = tl . exp ( qk * softmax_scale - m_ij [:, :, None ])
246243
247- if EVEN_N & EVEN_M :
248- if EVEN_HEADDIM :
249- v = tl .load (v_ptrs )
250- else :
251- v = tl .load (
252- v_ptrs ,
253- mask = offs_d [None , :] < headdim ,
254- other = 0.0
255- )
256- else :
257- if EVEN_HEADDIM :
258- v = tl .load (
259- v_ptrs ,
260- mask = offs_n [:, None ] < seqlen_k ,
261- other = 0.0 ,
262- )
244+ l_ij = tl .sum (p , 2 )
245+
246+ acc_o_scale = tl .exp (m_i - m_ij )
247+ acc_o *= acc_o_scale [:, :, None ]
248+
249+ if EVEN_N & EVEN_M :
250+ if EVEN_HEADDIM :
251+ v = tl .load (v_ptrs )
252+ else :
253+ v = tl .load (
254+ v_ptrs ,
255+ mask = offs_d [None , :] < headdim ,
256+ other = 0.0
257+ )
263258 else :
264- v = tl .load (
265- v_ptrs ,
266- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
267- other = 0.0 ,
268- )
259+ if EVEN_HEADDIM :
260+ v = tl .load (
261+ v_ptrs ,
262+ mask = offs_n [:, None ] < seqlen_k ,
263+ other = 0.0 ,
264+ )
265+ else :
266+ v = tl .load (
267+ v_ptrs ,
268+ mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
269+ other = 0.0 ,
270+ )
269271
270- p = p .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK ).to (v .dtype )
272+ p = p .reshape (BLOCK * QUERY_HEAD_GROUPS , BLOCK ).to (v .dtype )
271273
272- causal_o = tl .dot (p , v )
274+ causal_o = tl .dot (p , v )
273275
274- acc_o += causal_o .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
276+ acc_o += causal_o .reshape (BLOCK , QUERY_HEAD_GROUPS , BLOCK_HEADDIM )
275277
276- # -- update statistics
278+ # -- update statistics
277279
278- m_i = m_ij
279- l_i_new = tl .exp (lse_i - m_ij ) + l_ij
280- lse_i = m_ij + tl .log (l_i_new )
280+ m_i = m_ij
281+ l_i_new = tl .exp (lse_i - m_ij ) + l_ij
282+ lse_i = m_ij + tl .log (l_i_new )
281283
282284 # # take care of the selected kv blocks
283285
@@ -419,7 +421,8 @@ def native_sparse_attn_forward(
419421 v ,
420422 kv_block_indices ,
421423 kv_block_mask ,
422- block_size = 128
424+ block_size = 128 ,
425+ include_block_causal = True
423426):
424427 q , k , v , kv_block_indices = [x if is_contiguous (x ) else x .contiguous () for x in (q , k , v , kv_block_indices )]
425428
@@ -488,6 +491,7 @@ def native_sparse_attn_forward(
488491 QUERY_HEAD_GROUPS = head_groups ,
489492 QUERY_EXPAND_DIM = 16 // head_groups ,
490493 NUM_SEL_KV_BLOCKS = num_selected_fine_blocks ,
494+ INCLUDE_BLOCK_CAUSAL = include_block_causal ,
491495 num_warps = num_warps ,
492496 num_stages = 1 ,
493497 )
@@ -1184,14 +1188,19 @@ def backward_kernel(
11841188 BLOCK : tl .constexpr ,
11851189 QUERY_HEAD_GROUPS : tl .constexpr ,
11861190 QUERY_EXPAND_DIM : tl .constexpr ,
1191+ INCLUDE_BLOCK_CAUSAL : tl .constexpr
11871192):
11881193 off_hb = tl .program_id (1 )
11891194 off_b = off_hb // kv_heads
11901195 off_h = off_hb % kv_heads
11911196 off_qh = off_h * QUERY_HEAD_GROUPS
11921197
1193- IS_CAUSAL = tl .program_id (0 ) == 0
1194- OFF_SEL_KV_BLOCKS = tl .program_id (0 ) - 1
1198+ if INCLUDE_BLOCK_CAUSAL :
1199+ IS_CAUSAL = tl .program_id (0 ) == 0
1200+ OFF_SEL_KV_BLOCKS = tl .program_id (0 ) - 1
1201+ else :
1202+ IS_CAUSAL = False
1203+ OFF_SEL_KV_BLOCKS = tl .program_id (0 )
11951204
11961205 # offset pointers for batch/head
11971206
@@ -1310,7 +1319,8 @@ def native_sparse_attn_backward(
13101319 o ,
13111320 lse ,
13121321 dq , dk , dv ,
1313- block_size = 128
1322+ block_size = 128 ,
1323+ include_block_causal = True
13141324):
13151325 device = do .device
13161326
@@ -1362,7 +1372,10 @@ def native_sparse_attn_backward(
13621372 BLOCK_HEADDIM = BLOCK_HEADDIM ,
13631373 )
13641374
1365- grid = lambda META : (num_sel_fine_blocks + 1 , batch * kv_heads )
1375+ grid = lambda META : (
1376+ num_sel_fine_blocks + int (include_block_causal ),
1377+ batch * kv_heads
1378+ )
13661379
13671380 backward_kernel [grid ](
13681381 q ,
@@ -1418,7 +1431,8 @@ def native_sparse_attn_backward(
14181431 QUERY_EXPAND_DIM = 16 // head_groups ,
14191432 EVEN_M = divisible_by (seqlen_q , block_size ),
14201433 EVEN_N = divisible_by (seqlen_k , block_size ),
1421- EVEN_HEADDIM = BLOCK_HEADDIM == dim
1434+ EVEN_HEADDIM = BLOCK_HEADDIM == dim ,
1435+ INCLUDE_BLOCK_CAUSAL = include_block_causal ,
14221436 # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
14231437 # num_warps=num_warps,
14241438 # num_stages=1,
@@ -1440,6 +1454,7 @@ def forward(
14401454 block_size ,
14411455 selected_block_indices ,
14421456 fmask ,
1457+ include_block_causal
14431458 ):
14441459 dtype = fq .dtype
14451460
@@ -1453,14 +1468,16 @@ def forward(
14531468 fq , fk , fv ,
14541469 selected_block_indices ,
14551470 fmask ,
1456- block_size = block_size
1471+ block_size = block_size ,
1472+ include_block_causal = include_block_causal
14571473 )
14581474
14591475 ctx .save_for_backward (fq , fk , fv , selected_block_indices , fmask , out , lse )
14601476
14611477 ctx ._saved_variables = (
14621478 block_size ,
1463- head_groups
1479+ head_groups ,
1480+ include_block_causal
14641481 )
14651482
14661483 return out .type (dtype ), lse
@@ -1473,7 +1490,8 @@ def backward(self, ctx, do, _):
14731490
14741491 (
14751492 block_size ,
1476- head_groups
1493+ head_groups ,
1494+ include_block_causal
14771495 ) = ctx ._saved_variables
14781496
14791497 do = do .half ()
@@ -1485,7 +1503,8 @@ def backward(self, ctx, do, _):
14851503 do , q , k , v ,
14861504 sel_block_indices , mask ,
14871505 out , lse , dq , dk , dv ,
1488- block_size = block_size
1506+ block_size = block_size ,
1507+ include_block_causal = include_block_causal
14891508 )
14901509
14911510 return dq , dk , dv , None , None , None , None
@@ -1508,6 +1527,7 @@ def native_sparse_attend(
15081527 block_size : int ,
15091528 selected_block_indices : Int ['b qh n sel' ] | Int ['b kh n sel' ],
15101529 fmask : Bool ['b qh n sel' ] | Bool ['b kh n sel' ],
1530+ include_block_causal = True ,
15111531 return_lse = False
15121532):
15131533 seq_len = fq .shape [- 2 ]
@@ -1526,6 +1546,7 @@ def native_sparse_attend(
15261546 block_size ,
15271547 selected_block_indices ,
15281548 fmask ,
1549+ include_block_causal
15291550 )
15301551
15311552 if not return_lse :
0 commit comments