@@ -34,7 +34,7 @@ def is_contiguous(x: Tensor):
 from importlib.metadata import version

 try:
-    triton_version = version('triton-nightly')
+    triton_version = version('triton')
 except:
     print(f'latest triton must be installed. `{INSTALL_COMMAND}` first')
     exit()
@@ -216,17 +216,91 @@ def _fwd_kernel(
         l_i_new = tl.exp(lse_i - m_ij) + l_ij
         lse_i = m_ij + tl.log(l_i_new)

+    # take care of the selected kv blocks
+
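+    # each (batch, head, query row) has its own list of selected kv block indices plus a boolean mask marking which selections are valid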
+    kv_block_indices_ptrs = (
+        KV_block_indices +
+        off_b * stride_kvbl_b +
+        off_h * stride_kvbl_h +
+        offs_m * stride_kvbl_m
+    )
+
+    kv_block_mask_ptrs = (
+        KV_block_mask +
+        off_b * stride_kvbl_b +
+        off_h * stride_kvbl_h +
+        offs_m * stride_kvbl_m
+    )
+
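+    # loop over the NUM_SEL_KV_BLOCKS fine kv blocks selected for each query row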
+    for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
+        block_indices = tl.load(kv_block_indices_ptrs + off_sel_kv_block)
+        block_masks = tl.load(kv_block_mask_ptrs + off_sel_kv_block)
+
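+        # expand each query's selected block index into the absolute key positions that block covers -> (BLOCK, BLOCK)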
+        blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
+
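+        # gathered pointers: the leading dimension indexes the query row, so every query reads from its own key / value block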
+        block_k_ptrs = (
+            K + off_b * stride_kb + off_h * stride_kh + (blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :])
+        )
+
+        block_v_ptrs = (
+            V + off_b * stride_vb + off_h * stride_vh + (blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :])
+        )
+
+        # load k of shape (m, n, d), sparsely selected by each query
+
+        k_block = tl.load(block_k_ptrs)
+
+        # similarities
+
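+        # tl.dot needs matmul dimensions of at least 16, so each query is broadcast to 16
+        # identical rows and the redundant rows are averaged back out after the dot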
+        block_qk = tl.zeros([BLOCK, 16, BLOCK], dtype = tl.float32)
+        qk = tl.zeros([BLOCK, BLOCK], dtype = tl.float32)
+
+        k_block = tl.reshape(k_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
+        k_block = tl.permute(k_block, (0, 2, 1))
+
+        q_expanded = tl.expand_dims(q, 1)
+        q_expanded = tl.broadcast_to(q_expanded, (BLOCK, 16, BLOCK_HEADDIM))
+
+        block_qk = tl.dot(q_expanded, k_block)
+        qk += tl.sum(block_qk, 1) / 16.
+        qk += tl.where(block_masks[:, None], 0, float("-inf"))
+
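+        # online softmax update for this selected block, mirroring the main loop above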
+        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+        p = tl.exp(qk * softmax_scale - m_ij[:, None])
+
+        l_ij = tl.sum(p, 1)
+
+        acc_o_scale = tl.exp(m_i - m_ij)
+        acc_o = acc_o * acc_o_scale[:, None]
+
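+        # aggregate values with the same 16-row dot trick, then fold into the accumulator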
+        v_block = tl.load(block_v_ptrs)
+        v_block = tl.reshape(v_block, (BLOCK, BLOCK, BLOCK_HEADDIM))
+
+        p = p.to(v_block.dtype)
+        p_expanded = tl.expand_dims(p, 1)
+        p_expanded = tl.broadcast_to(p_expanded, (BLOCK, 16, BLOCK))
+
+        block_acc_o = tl.dot(p_expanded, v_block)
+        block_acc_o = tl.sum(block_acc_o, 1) / 16.
+        acc_o += block_acc_o
+
+        # -- update statistics
+
+        m_i = m_ij
+        l_i_new = tl.exp(lse_i - m_ij) + l_ij
+        lse_i = m_ij + tl.log(l_i_new)
+
     # normalize accumulated out

     acc_o_scale = tl.exp(m_i - lse_i)
     acc_o = acc_o * acc_o_scale[:, None]

-    # offsets for m and lse
+    # offsets

     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK + tl.arange(0, BLOCK)

-    # write back lse and m
+    # write back lse

     tl.store(lse_ptrs, lse_i)

@@ -253,7 +327,7 @@ def flash_attn_forward(
     kv_block_mask,
     block_size = 128
 ):
-    q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
+    q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]

     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -266,7 +340,7 @@ def flash_attn_forward(
     assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.is_cuda and k.is_cuda and v.is_cuda
+    assert all([t.is_cuda for t in (q, k, v)])

     softmax_scale = dim ** -0.5

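For intuition only, here is a rough single-(batch, head) PyTorch sketch of what the new selected-block branch computes, considering just the selected blocks (the kernel folds them into the same running softmax as its other branches). The function and argument names (selected_block_attn_reference, sel_idx, sel_mask) are illustrative assumptions, not taken from the repo.

import torch

def selected_block_attn_reference(q, k, v, sel_idx, sel_mask, block_size):
    # q: (seq_q, d), k / v: (seq_k, d)
    # sel_idx: (seq_q, num_sel) block index chosen per query
    # sel_mask: (seq_q, num_sel) bool, False marks an invalid / padded selection
    seq_q, d = q.shape
    scale = d ** -0.5

    # absolute key positions covered by each selected block: (seq_q, num_sel, block_size)
    pos = sel_idx[..., None] * block_size + torch.arange(block_size, device = sel_idx.device)

    k_sel = k[pos]   # (seq_q, num_sel, block_size, d), a different gather per query
    v_sel = v[pos]

    sim = torch.einsum('q d, q n b d -> q n b', q, k_sel) * scale
    sim = sim.masked_fill(~sel_mask[..., None], float('-inf'))

    attn = sim.reshape(seq_q, -1).softmax(dim = -1)
    return torch.einsum('q n, q n d -> q d', attn, v_sel.reshape(seq_q, -1, d))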