 def exists(v):
     return v is not None

+def divisible_by(num, den):
+    return (num % den) == 0
+
 def regular_attend(
     q, k, v,
     indices,
     mask,
-    block_size = None,
+    block_size,
 ):
-    q_heads, kv_heads = q.shape[1], k.shape[1]
+    q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
+    assert divisible_by(q_heads, kv_heads)
+
+    g = q_heads // kv_heads # `g` stands for `g`roups of query heads per kv head
+
+    assert divisible_by(seq_len, block_size)
+    w = seq_len // block_size

-    if exists(block_size):
-        w = q.shape[-2] // block_size
-        q, k, v = tuple(rearrange(t, 'b h (w n) d -> b (h w) n d', n = block_size) for t in (q, k, v))
+    q, k, v = tuple(rearrange(t, 'b h (w n) d -> b h w n d', n = block_size) for t in (q, k, v))

-    seq_len, device = q.shape[-2], q.device
     scale = q.shape[-1] ** -0.5
     q = q * scale

+    q = rearrange(q, 'b (h g) ... -> b h g ...', g = g)
+
     # block causal diagonal

-    sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
-    causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
+    sim = einsum(q, k, 'b h g w i d, b h w j d -> b h g w i j')
+    causal_mask = torch.ones((block_size, block_size), device = device, dtype = torch.bool).triu(1)
     sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

     # rest of the indices
@@ -37,48 +45,45 @@ def regular_attend(
     has_sel_kv_blocks = num_sel_kv_blocks > 0

     if has_sel_kv_blocks:
-        bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
+        bk, bv = k, v
         sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
         sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)

-        q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
-        bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j')
+        q = rearrange(q, 'b h g w n d -> b h g (w n) d')
+        bsim = einsum(q, sel_bk, 'b h g i d, b h i j d -> b h g i j')

-        bsim = rearrange(bsim, 'b h (w i) (sel j) -> b h w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)
+        bsim = rearrange(bsim, 'b h g (w i) (sel j) -> b h g w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)

-        mask = rearrange(mask, 'b h (w i) sel -> b h w i sel', i = fine_block_size)
+        mask = rearrange(mask, 'b h (w i) sel -> b h 1 w i sel', i = fine_block_size)
         bsim = torch.where(mask[..., None], bsim, -torch.finfo(bsim.dtype).max)

-        sim = rearrange(sim, 'b (h w) i j -> b h w i 1 j', h = q_heads)
+        sim = rearrange(sim, 'b h g w i j -> b h g w i 1 j')

         sim = torch.cat((sim, bsim), dim = -2)
-        sim = rearrange(sim, 'b h w i causal_and_sel j -> b h (w i) (causal_and_sel j)')
+        sim = rearrange(sim, 'b h g w i causal_and_sel j -> b h g w i (causal_and_sel j)')

         sel_bv = rearrange(sel_bv, 'b h (w i) j d -> b h w i j d', i = fine_block_size)

-        v = repeat(v, 'b (h w) j d -> b h w i j d', h = kv_heads, i = fine_block_size)
+        v = repeat(v, 'b h w j d -> b h w i j d', i = fine_block_size)
         v = torch.cat((v, sel_bv), dim = -2)
-        v = rearrange(v, 'b h w i j d -> b h (w i) j d')
+        v = rearrange(v, 'b h w i j d -> b h w i j d')

     # attend

     attn = sim.softmax(dim = -1)

     if has_sel_kv_blocks:
-        out = einsum(attn, v, 'b h i j, b h i j d -> b h i d')
+        out = einsum(attn, v, 'b h g w i j, b h w i j d -> b h g w i d')
     else:
-        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
-
-    if exists(block_size):
-        out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)
+        out = einsum(attn, v, 'b h g w i j, b h w j d -> b h g w i d')

-    return out
+    return rearrange(out, 'b h g w n d -> b (h g) (w n) d')

 # mock inputs

 fine_block_size = 16

-q = torch.randn(1, 2, 512, 64).cuda()
+q = torch.randn(1, 4, 512, 64).cuda()
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()

@@ -97,7 +102,7 @@ def regular_attend(

 # triton nsa forwards and backwards

-nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, 1)
+nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask)
 nsa_out.sum().backward()

 # asserts
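
# aside: a minimal sketch (not part of the commit) of the shape bookkeeping
# behind the new `g` / `w` axes, i.e. grouped query heads per kv head and
# sequence blocks, using the same mock sizes as the test; assumes only torch
# and einops

import torch
from einops import rearrange, einsum

b, q_heads, kv_heads, seq_len, dim, block_size = 1, 4, 2, 512, 64, 16

g = q_heads // kv_heads    # query-head groups per kv head -> 2
w = seq_len // block_size  # number of sequence blocks -> 32

q = torch.randn(b, q_heads, seq_len, dim)
k = torch.randn(b, kv_heads, seq_len, dim)

# block the sequence, then split query heads into (kv head, group)
q, k = (rearrange(t, 'b h (w n) d -> b h w n d', n = block_size) for t in (q, k))
q = rearrange(q, 'b (h g) ... -> b h g ...', g = g)

# each query group attends against the single kv head it shares, one
# block-diagonal chunk at a time, so the causal mask only needs to be
# block_size x block_size rather than seq_len x seq_len
sim = einsum(q, k, 'b h g w i d, b h w j d -> b h g w i j')
print(sim.shape)  # torch.Size([1, 2, 2, 32, 16, 16])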
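
# aside: a rough sketch of the per-position block gather that einx.get_at
# performs above, reusing the exact pattern from the diff; shapes follow the
# mock inputs (2 kv heads, 32 blocks of 16 tokens, assuming 2 selected blocks
# per query position)

import torch
import einx

b, h, w, n, d, seq_len, sel = 1, 2, 32, 16, 64, 512, 2

bk = torch.randn(b, h, w, n, d)                      # keys, already blocked
indices = torch.randint(0, w, (b, h, seq_len, sel))  # chosen block ids per position

# for every query position i, pull out its `sel` chosen key blocks and
# flatten them into one key axis of length sel * n
sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
print(sel_bk.shape)  # torch.Size([1, 2, 512, 32, 64])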