@@ -165,6 +165,7 @@ def __init__(
         query_bias = True,
         flash = True,
         window_size = None,
+        num_memory_kv: int = 0,
         efficient_attn_config: Config = Config(True, True, True)
     ):
         super().__init__()
@@ -178,6 +179,7 @@ def __init__(
         e - dimension (pairwise rep)
         i - source sequence
         j - context sequence
+        m - memory key / value seq
         """

         dim_inner = dim_head * heads
@@ -196,6 +198,12 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)

+        self.memory_kv = None
+
+        if num_memory_kv > 0:
+            self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
+            nn.init.normal_(self.memory_kv, std = 0.02)
+
         # gating of value
         # allows attention to attend to nothing
@@ -230,7 +238,8 @@ def forward(
         out = self.attend(
             q, k, v,
             attn_bias = attn_bias,
-            mask = mask
+            mask = mask,
+            memory_kv = self.memory_kv
         )

         # merge heads
@@ -315,7 +324,8 @@ def local_attn(
         k: Float['b h n d'],
         v: Float['b h n d'],
         mask: Bool['b n'] | None = None,
-        attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None
+        attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None,
+        memory_kv: Float['2 h m d'] | None = None
     ) -> Float['b h n d']:
         """
         simple local attention with a radius of 1 window size
@@ -363,6 +373,24 @@ def local_attn(

         q = q * scale

+        # append memory key / values for local attention windows
+
+        if exists(memory_kv):
+            batch, seq, num_mem_kv = k.shape[0], k.shape[2], memory_kv.shape[-2]
+
+            mk, mv = memory_kv
+            mk, mv = tuple(repeat(t, 'h m d -> b h n m d', b = batch, n = seq) for t in (mk, mv))
+            k = torch.cat((mk, k), dim = -2)
+            v = torch.cat((mv, v), dim = -2)
+
+            if exists(attn_bias):
+                attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)
+
+            if exists(mask):
+                mask = pad_at_dim(mask, (num_mem_kv, 0), value = True)
+
+        # similarity
+
         sim = einsum(q, k, "... i d, ... j d -> ... i j")

         if exists(attn_bias):
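
Shape note (editorial, not part of the diff): a minimal trace of the windowed memory key / value concatenation added above. The layout is inferred from the repeat pattern 'h m d -> b h n m d' and the cat along dim -2; the exact per-window key count j depends on how local_attn buckets the sequence and is an assumption here.

# memory_kv           : (2, h, m, d)          learned parameter, shared across batch and windows
# k before concat     : (b, h, nw, j, d)      keys already grouped into nw windows (n = nw = k.shape[2])
# mk after repeat     : (b, h, nw, m, d)      broadcast over the batch and over every window
# k after cat(dim=-2) : (b, h, nw, m + j, d)  m memory slots prepended to each window's keys
# attn_bias and mask are left-padded by m along the key dimension so the memory slots stay attendable
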
@@ -399,6 +427,7 @@ def forward(
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None,
         attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
+        memory_kv: Float['2 h m d'] | None = None
     ) -> Float['b h i d']:

         is_windowed_attn_bias = None
@@ -410,10 +439,26 @@ def forward(
         # todo (handle attn bias efficiently)

         if self.is_local_attn:
-            return self.local_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            return self.local_attn(q, k, v, mask = mask, attn_bias = attn_bias, memory_kv = memory_kv)

         assert not exists(is_windowed_attn_bias) or not is_windowed_attn_bias

+        # append memory key / values
+
+        if exists(memory_kv):
+            batch, num_mem_kv = q.shape[0], memory_kv.shape[-2]
+
+            mk, mv = memory_kv
+            mk, mv = tuple(repeat(t, 'h m d -> b h m d', b = batch) for t in (mk, mv))
+            k = torch.cat((mk, k), dim = -2)
+            v = torch.cat((mv, v), dim = -2)
+
+            if exists(attn_bias):
+                attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)
+
+            if exists(mask):
+                mask = pad_at_dim(mask, (num_mem_kv, 0), value = True)
+
         # forward to using flash attention if applicable

         can_use_flash = self.flash and not exists(attn_bias) # flash attention does not support attention bias with gradients
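
Usage sketch (editorial, not part of the diff): how the new num_memory_kv option could be exercised. This assumes the hunks above belong to the repository's Attention module and that its forward accepts a token sequence plus an optional mask; everything outside the shown hunks (class name, remaining constructor defaults, call signature) is an assumption, not confirmed by the diff.

import torch

# hypothetical construction; dim, dim_head, heads, window_size and the new
# num_memory_kv all appear in the hunks above, the rest is assumed
attn = Attention(
    dim = 64,
    dim_head = 16,
    heads = 4,
    window_size = None,   # global attention path; local_attn also receives the memory kv
    num_memory_kv = 4     # allocates a learned (2, heads, 4, dim_head) parameter
)

x = torch.randn(2, 128, 64)
mask = torch.ones(2, 128).bool()

out = attn(x, mask = mask)   # every query can additionally attend to the 4 memory slots
assert out.shape == x.shape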