@@ -15,16 +15,36 @@ def regular_attend(
     mask,
     block_size = None,
 ):
+    q_heads, kv_heads = q.shape[1], k.shape[1]
+
     if exists(block_size):
         w = q.shape[-2] // block_size
         q, k, v = tuple(rearrange(t, 'b h (w n) d -> b (h w) n d', n = block_size) for t in (q, k, v))

     seq_len, device = q.shape[-2], q.device
     scale = q.shape[-1] ** -0.5
+    q = q * scale
+
+    # block causal diagonal

-    sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
+    sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
     causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
     sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
+    # rest of the indices
+
+    has_sel_kv_blocks = indices.shape[-1] > 0
+
+    if has_sel_kv_blocks:
+        bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
+        sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
+        sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)
+
+        q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
+        bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j')
+
+    # attend
+
     attn = sim.softmax(dim = -1)

     out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
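
The core addition in this diff is the einx.get_at gather: for each query position, a set of selected kv block indices is used to pull whole key/value blocks out of the block-partitioned tensors and flatten them into a per-query kv axis. The sketch below is not part of the commit; it only illustrates that gather pattern in isolation, with all shapes chosen arbitrarily for the example.

import torch
import einx

b, h, w, n, d = 1, 2, 8, 4, 16    # batch, kv heads, kv blocks, block size, head dim
i, sel = 32, 3                    # query length, selected blocks per query

bk = torch.randn(b, h, w, n, d)                # keys partitioned into w blocks of size n
indices = torch.randint(0, w, (b, h, i, sel))  # chosen block ids for each query position

# gather along the bracketed [w] axis, then merge (sel n) into one kv dimension per query
sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
assert sel_bk.shape == (b, h, i, sel * n, d)

Because every query selects the same number of blocks, the gathered keys form a fixed-size (sel * n) kv axis per query, which is what allows the subsequent per-query similarity einsum ('b h i d, b h i j d -> b h i j') to be computed densely.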