# einstein notation

import einx
- from einops import einsum, repeat, rearrange
+ from einops import einsum, repeat, rearrange, reduce
from einops.layers.torch import Rearrange

# b - batch
@@ -66,6 +66,9 @@ def round_down_mult(n, mult):
def round_up_mult(n, mult):
    return ceil(n / mult) * mult

+ def divisible_by(num, den):
+     return (num % den) == 0
+
def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
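A minimal usage sketch of these helpers with hypothetical sizes; `pad_at_dim` is assumed to end with the usual `F.pad(t, (*zeros, *pad), value = value)` return, which this hunk cuts off:

import torch
import torch.nn.functional as F
from math import ceil

def round_up_mult(n, mult):
    return ceil(n / mult) * mult

def divisible_by(num, den):
    return (num % den) == 0

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# a sequence of length 10 with selection block size 4 rounds up to 12,
# and the 2 missing positions get padded on the right of the sequence dim
seq = torch.randn(1, 10, 64)
target_len = round_up_mult(seq.shape[-2], 4)                       # 12
seq = pad_at_dim(seq, (0, target_len - seq.shape[-2]), dim = -2)
assert seq.shape[-2] == 12 and not divisible_by(10, 4)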
@@ -83,6 +86,7 @@ def __init__(
    compress_block_size,
    selection_block_size,
    num_selected_blocks,
+     num_kv_heads = None,
    num_compressed_mem_kv = 4,
    norm = True,
    use_diff_topk = False,
@@ -91,12 +95,25 @@ def __init__(
    strategy_combine_mlp: Module | None = None
):
    super().__init__()
+
+     # attention heads
+     # handling gqa if `num_kv_heads` is set
+
+     num_kv_heads = default(num_kv_heads, heads)
+     assert num_kv_heads <= heads and divisible_by(heads, num_kv_heads)
+
    self.heads = heads
+     self.num_kv_heads = num_kv_heads
+     self.num_grouped_queries = heads // num_kv_heads
+
+     # scale
+
    self.scale = dim_head ** -0.5

    assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'

    dim_inner = dim_head * heads
+     dim_kv_inner = dim_head * num_kv_heads

    self.norm = nn.RMSNorm(dim) if norm else nn.Identity()

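A minimal sketch of the grouped-query bookkeeping above, with hypothetical sizes:

heads, dim_head = 8, 64
num_kv_heads = 2                               # `None` would default to `heads`, i.e. plain multi-head attention

assert num_kv_heads <= heads and (heads % num_kv_heads) == 0

num_grouped_queries = heads // num_kv_heads    # 4 query heads share each key / value head

dim_inner = dim_head * heads                   # 512 - width of the query projection
dim_kv_inner = dim_head * num_kv_heads         # 128 - width of each of the key / value projections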
@@ -106,7 +123,11 @@ def __init__(

    # qkv

-     self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
+     qkv_split = (dim_inner, dim_kv_inner, dim_kv_inner)
+
+     self.to_qkv = nn.Linear(dim, sum(qkv_split), bias = False)
+
+     self.qkv_split = qkv_split

    # sliding window strategy

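A sketch of the asymmetric fused projection and the matching `split` used later in `forward`, continuing the hypothetical 8 query heads / 2 kv heads example:

import torch
import torch.nn as nn

dim, dim_head, heads, num_kv_heads = 512, 64, 8, 2

dim_inner = dim_head * heads
dim_kv_inner = dim_head * num_kv_heads
qkv_split = (dim_inner, dim_kv_inner, dim_kv_inner)

to_qkv = nn.Linear(dim, sum(qkv_split), bias = False)

x = torch.randn(1, 10, dim)
q, k, v = to_qkv(x).split(qkv_split, dim = -1)

assert q.shape[-1] == 512 and k.shape[-1] == 128 and v.shape[-1] == 128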
@@ -129,10 +150,10 @@ def __init__(

    self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)

-     self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
+     self.compress_mem_kv = nn.Parameter(torch.zeros(2, num_kv_heads, num_compressed_mem_kv, dim_head))

-     self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
-     self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
+     self.k_intrablock_positions = nn.Parameter(torch.zeros(num_kv_heads, compress_block_size, dim_head))
+     self.v_intrablock_positions = nn.Parameter(torch.zeros(num_kv_heads, compress_block_size, dim_head))

    if not exists(compress_mlp):
        compress_dim = compress_block_size * dim_head
@@ -168,7 +189,7 @@ def __init__(

    # split and merging heads

-     self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+     self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
    self.merge_heads = Rearrange('b h n d -> b n (h d)')

    # combining heads
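Keying the head split on `d = dim_head` instead of `h = heads` is what lets the same `Rearrange` handle both the wide query tensor and the narrower key / value tensors; a sketch with the same hypothetical sizes:

import torch
from einops.layers.torch import Rearrange

dim_head = 64
split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)

q = torch.randn(1, 10, 8 * dim_head)    # 8 query heads
k = torch.randn(1, 10, 2 * dim_head)    # 2 key / value heads

assert split_heads(q).shape == (1, 8, 10, dim_head)
assert split_heads(k).shape == (1, 2, 10, dim_head)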
@@ -194,7 +215,7 @@ def forward(

    # queries, keys, values

-     q, k, v = self.to_qkv(inp).chunk(3, dim = -1)
+     q, k, v = self.to_qkv(inp).split(self.qkv_split, dim = -1)

    q, k, v = map(self.split_heads, (q, k, v))

@@ -218,6 +239,8 @@ def forward(
    ck = cat((mem_ck, ck), dim = -2)
    cv = cat((mem_cv, cv), dim = -2)

+     ck, cv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (ck, cv))
+
    csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * self.scale

    cq_seq = arange(seq_len, device = device)
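The compressed keys / values only carry `num_kv_heads` heads, so they are tiled up to the full query head count right before the similarity einsum; a shape-level sketch with hypothetical shapes:

import torch
from einops import repeat, einsum

num_grouped_queries = 4

q  = torch.randn(1, 8, 10, 64)    # (b, query heads, i, d)
ck = torch.randn(1, 2, 6, 64)     # (b, kv heads, compressed blocks + mem kv, d)

ck = repeat(ck, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = num_grouped_queries)

csim = einsum(q, ck, 'b h i d, b h j d -> b h i j')
assert csim.shape == (1, 8, 10, 6)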
@@ -241,8 +264,13 @@ def forward(

    # 2. fine attention over selected based on compressed attention logits

+
    importance_scores = cattn[..., num_mem_compress_kv:]

+     # for gqa, we will average the compressed attention across each grouped queries (per key / values)
+
+     importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+
    num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])

    fq = rotated_q
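Block selection has to be made once per key / value head, so the compressed attention weights of the grouped query heads that share a kv head are averaged back down before the top-k; a sketch with hypothetical shapes:

import torch
from einops import reduce

num_grouped_queries = 4

cattn = torch.randn(1, 8, 10, 6).softmax(dim = -1)   # (b, query heads, i, compressed blocks)

importance_scores = reduce(cattn, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = num_grouped_queries)

assert importance_scores.shape == (1, 2, 10, 6)      # one score per kv head and compressed block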
@@ -273,13 +301,13 @@ def forward(
    # handle block causal diagonal in the diagram, but run experiments without to see

    fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-     fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
+     fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.num_kv_heads)
    selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

    fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

    causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-     causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)
+     causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.num_kv_heads)

    fmask = cat((fmask, causal_mask), dim = -2)
    fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
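A shape-level sketch of the mask assembly above, built at kv-head resolution (hypothetical sizes: batch 1, 2 kv heads, selection block 16, 2 fine blocks, 2 selected blocks per position):

import torch
from einops import repeat, rearrange

batch, num_kv_heads = 1, 2
selection_block_size, num_fine_blocks = 16, 2
fine_divisible_seq_len = selection_block_size * num_fine_blocks

selected_block_indices = torch.zeros(batch, num_kv_heads, fine_divisible_seq_len, 2, dtype = torch.long)
fmask = torch.ones(batch, num_kv_heads, fine_divisible_seq_len, 2, dtype = torch.bool)

# append the block each query position lives in, for the block causal diagonal
fine_window_seq = torch.arange(fine_divisible_seq_len) // selection_block_size
fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = num_kv_heads)
selected_block_indices = torch.cat((selected_block_indices, fine_window_seq), dim = -1)   # (1, 2, 32, 3)

# expand the per-block mask to per-position, and tack on the causal diagonal block
fmask = repeat(fmask, 'b h i w -> b h i w j', j = selection_block_size)
causal_mask = torch.ones((selection_block_size,) * 2, dtype = torch.bool).tril()
causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = num_kv_heads)

fmask = torch.cat((fmask, causal_mask), dim = -2)
fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
assert fmask.shape == (1, 2, 32, 3 * 16)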
@@ -312,6 +340,8 @@ def forward(

    # fine attention

+     fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+
    fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale

    fsim = fsim.masked_fill(~fmask, mask_value)
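In the fine branch every query position has its own gathered set of selected keys / values, hence the extra `i` axis on the key side of the einsum; a sketch with the kv tensors already repeated across the grouped queries (hypothetical shapes; the closing attend-to-values einsum is assumed, it is not shown in this hunk):

import torch
from einops import einsum

scale = 64 ** -0.5
mask_value = -torch.finfo(torch.float32).max

fq    = torch.randn(1, 8, 32, 64)          # (b, heads, i, d)
fk    = torch.randn(1, 8, 32, 3 * 16, 64)  # per position: 3 blocks of 16 gathered keys
fv    = torch.randn(1, 8, 32, 3 * 16, 64)
fmask = torch.ones(1, 8, 32, 3 * 16, dtype = torch.bool)

fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * scale
fsim = fsim.masked_fill(~fmask, mask_value)
fattn = fsim.softmax(dim = -1)

fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
assert fine_attn_out.shape == (1, 8, 32, 64)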
@@ -327,6 +357,8 @@ def forward(
    seq_len = fk.shape[-2]
    fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()

+     fk, fv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))
+
    fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale

    fsim = fsim.masked_fill(~fmask, mask_value)
@@ -337,10 +369,16 @@ def forward(

    # 3. overlapping sliding window, this is unsurprising and expected

+     sq = rotated_q
+     sk = rotated_k
+     sv = v
+
+     sk, sv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
+
    if exists(sliding_window_flex_mask):
-         sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask)
+         sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask)
    else:
-         sliding_window_attn_out = self.sliding_window(rotated_q, rotated_k, v)
+         sliding_window_attn_out = self.sliding_window(sq, sk, sv)

    # combine strategies

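The same repeat keeps the sliding window branch head-count compatible; a sketch with a plain causal attention standing in for the local / sliding window attention module, which is not shown in this diff (hypothetical shapes; recent PyTorch releases also expose an `enable_gqa` flag on `F.scaled_dot_product_attention` that can avoid materializing the repeat, where available):

import torch
import torch.nn.functional as F
from einops import repeat

num_grouped_queries = 4

sq = torch.randn(1, 8, 10, 64)    # rotated queries, all heads
sk = torch.randn(1, 2, 10, 64)    # rotated keys, kv heads only
sv = torch.randn(1, 2, 10, 64)

sk, sv = tuple(repeat(t, 'b h ... -> b (g h) ...', g = num_grouped_queries) for t in (sk, sv))

# stand-in for the sliding window attention over the repeated keys / values
out = F.scaled_dot_product_attention(sq, sk, sv, is_causal = True)
assert out.shape == (1, 8, 10, 64)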