 from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
 
 import einx
-from einops import rearrange, einsum
+from einops import rearrange, einsum, repeat
 
 assert torch.cuda.is_available()
 
@@ -33,7 +33,8 @@ def regular_attend(
 
     # rest of the indices
 
-    has_sel_kv_blocks = indices.shape[-1] > 0
+    num_sel_kv_blocks = indices.shape[-1]
+    has_sel_kv_blocks = num_sel_kv_blocks > 0
 
     if has_sel_kv_blocks:
         bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
@@ -43,14 +44,33 @@ def regular_attend(
         q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
         bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j') * scale
 
+        bsim = rearrange(bsim, 'b h (w i) (sel j) -> b h w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)
+
+        mask = rearrange(mask, 'b h (w i) sel -> b h w i sel', i = fine_block_size)
+        bsim = torch.where(mask[..., None], bsim, -torch.finfo(bsim.dtype).max)
+
+        sim = rearrange(sim, 'b (h w) i j -> b h w i 1 j', h = q_heads)
+
+        sim = torch.cat((sim, bsim), dim = -2)
+        sim = rearrange(sim, 'b h w i causal_and_sel j -> b h (w i) (causal_and_sel j)')
+
+        sel_bv = rearrange(sel_bv, 'b h (w i) j d -> b h w i j d', i = fine_block_size)
+
+        v = repeat(v, 'b (h w) j d -> b h w i j d', h = kv_heads, i = fine_block_size)
+        v = torch.cat((v, sel_bv), dim = -2)
+        v = rearrange(v, 'b h w i j d -> b h (w i) j d')
+
     # attend
 
     attn = sim.softmax(dim = -1)
 
-    out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+    if has_sel_kv_blocks:
+        out = einsum(attn, v, 'b h i j, b h i j d -> b h i d')
+    else:
+        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
 
-    if exists(block_size):
-        out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)
+    if exists(block_size):
+        out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)
 
     return out
 
@@ -62,8 +82,8 @@ def regular_attend(
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()
 
-indices = torch.zeros(1, 2, 512, 1).long().cuda()
-mask = torch.zeros(1, 2, 512, 1).bool().cuda()
+indices = torch.zeros(1, 2, 512, 0).long().cuda()
+mask = torch.zeros(1, 2, 512, 0).bool().cuda()
 
 # both regular and nsa pathways `r` and `n`
 
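For orientation, below is a minimal sketch of how the regular and nsa pathways referenced in the comment above might be compared. The query tensor, its head count, the block size of 64, and both call signatures are assumptions not shown in these hunks; treat it as an illustration, not the script's actual test body.

import torch

# hypothetical inputs mirroring the k / v tensors above; q's head count is assumed
q = torch.randn(1, 4, 512, 64).cuda()
k = torch.randn(1, 2, 512, 64).cuda()
v = torch.randn(1, 2, 512, 64).cuda()

indices = torch.zeros(1, 2, 512, 0).long().cuda()
mask = torch.zeros(1, 2, 512, 0).bool().cuda()

# clone inputs for the regular pathway `r` and the nsa pathway `n`
rq, rk, rv = (t.clone().requires_grad_() for t in (q, k, v))
nq, nk, nv = (t.clone().requires_grad_() for t in (q, k, v))

ro = regular_attend(rq, rk, rv, indices, mask, block_size = 64)   # assumed signature
no = native_sparse_attend(nq, nk, nv, 64, indices, mask)          # assumed signature

# the two outputs should agree up to numerical tolerance
assert torch.allclose(ro, no, atol = 1e-2)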