@@ -16,6 +16,9 @@ def exists(v):
 def default(val, d):
     return val if exists(val) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def round_up_multiple(n, mult):
     return ceil(n / mult) * mult
 
@@ -49,8 +52,8 @@ def is_contiguous(x: Tensor):
 
 @triton.heuristics(
     {
-        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK"] == 0,
-        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK"] == 0,
+        "EVEN_M": lambda args: divisible_by(args["seqlen_q"], args["BLOCK"]),
+        "EVEN_N": lambda args: divisible_by(args["seqlen_k"], args["BLOCK"]),
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
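Each entry in the @triton.heuristics dict is a callable that receives the kernel's argument dict and resolves to a compile-time flag; the new divisible_by helper just names the modulo check. A minimal sketch of what these lambdas compute, with hypothetical sizes:

def divisible_by(num, den):
    return (num % den) == 0

# hypothetical argument dict, mirroring what the lambdas above receive
args = dict(seqlen_q = 1024, seqlen_k = 1000, BLOCK = 128, headdim = 64, BLOCK_HEADDIM = 64)

EVEN_M = divisible_by(args["seqlen_q"], args["BLOCK"])   # True:  1024 % 128 == 0
EVEN_N = divisible_by(args["seqlen_k"], args["BLOCK"])   # False: 1000 % 128 == 104
EVEN_HEADDIM = args["headdim"] == args["BLOCK_HEADDIM"]  # True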
@@ -335,14 +338,14 @@ def flash_attn_forward(
 ):
     q, k, v, kv_block_indices = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v, kv_block_indices)]
 
-    batch, seqlen_q, nheads, dim = q.shape
-    _, seqlen_k, _, _ = k.shape
+    batch, nheads, seqlen_q, dim, device = *q.shape, q.device
+    _, _, seqlen_k, _ = k.shape
 
     num_selected_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
 
-    assert k.shape == (batch, seqlen_k, nheads, dim)
-    assert v.shape == (batch, seqlen_k, nheads, dim)
+    assert k.shape == (batch, nheads, seqlen_k, dim)
+    assert v.shape == (batch, nheads, seqlen_k, dim)
     assert dim <= 128, "only support head dimensions up to 128"
     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
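Note the layout change in this hunk: flash_attn_forward now unpacks q as (batch, nheads, seqlen_q, dim) rather than (batch, seqlen_q, nheads, dim), so callers must pass head-first tensors. A hedged sketch of inputs that satisfy the updated asserts (sizes are hypothetical):

import torch

batch, nheads, seqlen, dim = 2, 8, 1024, 64

q = torch.randn(batch, nheads, seqlen, dim, dtype = torch.float16, device = 'cuda')
k = torch.randn(batch, nheads, seqlen, dim, dtype = torch.float16, device = 'cuda')
v = torch.randn(batch, nheads, seqlen, dim, dtype = torch.float16, device = 'cuda')

assert q.shape == (batch, nheads, seqlen, dim)
assert k.shape == v.shape == q.shape  # matches the reordered asserts above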
@@ -352,9 +355,9 @@ def flash_attn_forward(
 
     seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
 
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device = q.device, dtype = torch.float32)
+    lse = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
 
-    m = torch.empty((batch, nheads, seqlen_q_rounded), device = q.device, dtype = torch.float32)
+    m = torch.empty((batch, nheads, seqlen_q_rounded), device = device, dtype = torch.float32)
 
     o = torch.empty_like(q)
 
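The lse and m buffers are padded along the query dimension to a multiple of TRITON_BLOCK_SIZE via round_up_multiple. A quick worked example, where TRITON_BLOCK_SIZE = 128 is an assumed value for illustration (not taken from the diff):

from math import ceil

def round_up_multiple(n, mult):
    return ceil(n / mult) * mult

TRITON_BLOCK_SIZE = 128   # assumed value
seqlen_q = 1000
seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)  # ceil(1000 / 128) * 128 = 1024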
@@ -373,20 +376,20 @@ def flash_attn_forward(
         lse,
         softmax_scale,
         q.stride(0),
-        q.stride(2),
         q.stride(1),
+        q.stride(2),
         k.stride(0),
-        k.stride(2),
         k.stride(1),
+        k.stride(2),
         v.stride(0),
-        v.stride(2),
         v.stride(1),
+        v.stride(2),
         o.stride(0),
-        o.stride(2),
         o.stride(1),
+        o.stride(2),
         kv_block_indices.stride(0),
-        kv_block_indices.stride(2),
         kv_block_indices.stride(1),
+        kv_block_indices.stride(2),
         nheads,
         seqlen_q,
         seqlen_k,
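The stride arguments are reordered to match the head-first layout: stride(1) is now the head stride and stride(2) the sequence stride, passed in (batch, head, seq) order. A small sketch with a hypothetical contiguous tensor:

import torch

q = torch.empty(2, 8, 1024, 64)            # (batch, heads, seq, dim), contiguous
print(q.stride(0), q.stride(1), q.stride(2))
# 524288 65536 64 -> batch, head, and sequence strides, in the order passed above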
@@ -964,8 +967,8 @@ def flash_attn_backward(
     if not is_contiguous(do):
         do = do.contiguous()
 
-    batch, seqlen_q, nheads, dim = q.shape
-    _, seqlen_k, _, _ = k.shape
+    batch, nheads, seqlen_q, dim = q.shape
+    _, _, seqlen_k, _ = k.shape
 
     num_sel_fine_blocks = kv_block_indices.shape[-1]
     assert kv_block_indices.shape == kv_block_mask.shape
@@ -995,11 +998,11 @@ def flash_attn_backward(
         do,
         delta,
         o.stride(0),
-        o.stride(2),
         o.stride(1),
+        o.stride(2),
         do.stride(0),
-        do.stride(2),
         do.stride(1),
+        do.stride(2),
         nheads,
         seqlen_q,
         seqlen_q_rounded,
@@ -1027,29 +1030,29 @@ def flash_attn_backward(
         delta,
         softmax_scale,
         q.stride(0),
-        q.stride(2),
         q.stride(1),
+        q.stride(2),
         k.stride(0),
-        k.stride(2),
         k.stride(1),
+        k.stride(2),
         v.stride(0),
-        v.stride(2),
         v.stride(1),
+        v.stride(2),
         do.stride(0),
-        do.stride(2),
         do.stride(1),
+        do.stride(2),
         dq_accum.stride(0),
-        dq_accum.stride(2),
         dq_accum.stride(1),
+        dq_accum.stride(2),
         dk.stride(0),
-        dk.stride(2),
         dk.stride(1),
+        dk.stride(2),
         dv.stride(0),
-        dv.stride(2),
         dv.stride(1),
+        dv.stride(2),
         kv_block_indices.stride(0),
-        kv_block_indices.stride(2),
         kv_block_indices.stride(1),
+        kv_block_indices.stride(2),
         nheads,
         seqlen_q,
         seqlen_k,
@@ -1063,8 +1066,8 @@ def flash_attn_backward(
         BLOCK = block_size,
         NUM_SEL_KV_BLOCKS = num_sel_fine_blocks,
         SEQUENCE_PARALLEL = False,
-        EVEN_M = (seqlen_q % block_size) == 0,
-        EVEN_N = (seqlen_k % block_size) == 0,
+        EVEN_M = divisible_by(seqlen_q, block_size),
+        EVEN_N = divisible_by(seqlen_k, block_size),
         EVEN_HEADDIM = BLOCK_HEADDIM == dim
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
@@ -1093,10 +1096,6 @@ def forward(
         fmask,
         num_grouped_queries
     ):
-        selected_block_indices, fmask = tuple(rearrange(t, 'b h i sel -> b i h sel') for t in (selected_block_indices, fmask))
-
-        fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
-
         dtype = fq.dtype
 
         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))
@@ -1111,12 +1110,10 @@ def forward(
         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
         ctx._saved_variables = (block_size,)
 
-        out = rearrange(out, 'b n h d -> b h n d')
         return out.type(dtype)
 
     @classmethod
     def backward(self, ctx, do):
-        do = rearrange(do, 'b h n d -> b n h d')
 
         q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors
 
@@ -1136,7 +1133,6 @@ def backward(self, ctx, do):
             block_size = block_size
         )
 
-        dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
         return dq, dk, dv, None, None, None, None
 
 native_sparse_attend = NSA.apply
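With the rearranges removed from both forward and backward, native_sparse_attend now consumes and returns tensors directly in (batch, heads, seq, dim) layout. A hedged usage sketch, assuming the argument order visible in the forward signature above (fq, fk, fv, block_size, selected_block_indices, fmask, num_grouped_queries) and hypothetical sizes:

import torch

batch, heads, seq, dim = 2, 8, 1024, 64
block_size, num_sel = 64, 4

fq, fk, fv = (torch.randn(batch, heads, seq, dim, dtype = torch.float16, device = 'cuda') for _ in range(3))

# selected block indices and mask stay in head-first layout as well
sel_indices = torch.zeros(batch, heads, seq, num_sel, dtype = torch.long, device = 'cuda')
sel_mask = torch.ones(batch, heads, seq, num_sel, dtype = torch.bool, device = 'cuda')

out = native_sparse_attend(fq, fk, fv, block_size, sel_indices, sel_mask, 1)  # num_grouped_queries = 1 assumed
assert out.shape == fq.shape  # no 'b h n d' <-> 'b n h d' rearranges needed anymore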