Skip to content

Commit 42d84bd

Browse files
committed
make a rough pass over the initialization logic + exponential moving average of the concept text encoding token
1 parent b1fdfae commit 42d84bd

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,44 @@ Implementation of <a href="https://arxiv.org/abs/2305.01644">Key-Locked Rank One
66

77
It seems they successfully applied the Rank-1 editing technique from a <a href="https://arxiv.org/abs/2202.05262">memory editing paper for LLM</a>, with a few improvements. They also identified that the keys determine the "where" of the new concept, while the values determine the "what", and propose local / global-key locking to a superclass concept (while learning the values).
88

9+
## Install
10+
11+
```bash
12+
$ pip install perfusion-pytorch
13+
```
14+
15+
## Usage
16+
17+
```python
18+
import torch
19+
from perfusion_pytorch import Rank1EditModule
20+
from torch import nn
21+
22+
to_keys = nn.Linear(512, 1024, bias = False)
23+
to_values = nn.Linear(512, 1024, bias = False)
24+
25+
C = torch.randn(512, 512)
26+
27+
wrapped_to_keys = Rank1EditModule(
28+
to_keys,
29+
C = C,
30+
num_finetune_prompts = 32
31+
)
32+
33+
wrapped_to_values = Rank1EditModule(
34+
to_values,
35+
C = C,
36+
num_finetune_prompts = 32
37+
)
38+
39+
prompt_ids = torch.arange(4).long()
40+
text_enc = torch.randn(4, 1024, 512)
41+
concept_ids = torch.randint(0, 1024, (4,))
42+
43+
keys = wrapped_to_keys(prompt_ids, text_enc, concept_ids)
44+
values = wrapped_to_values(prompt_ids, text_enc, concept_ids)
45+
```
46+
947
## Citations
1048

1149
```bibtex

perfusion_pytorch/perfusion.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
self,
2222
key_or_values_proj: nn.Linear,
2323
*,
24+
num_finetune_prompts: int,
2425
C: Tensor,
2526
input_decay = 0.99,
2627
train_beta = 0.75,
@@ -32,6 +33,7 @@ def __init__(
3233
assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
3334

3435
self.weight = key_or_values_proj.weight
36+
dim_input = self.weight.shape[-1]
3537

3638
self.train_beta = train_beta
3739
self.train_temperature = train_temperature
@@ -40,18 +42,61 @@ def __init__(
4042

4143
self.input_decay = input_decay
4244

45+
# they exponentially smooth the text encoding inputs during training
46+
# in addition to a lowered learning rate on the text encodings
47+
48+
self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
49+
self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
50+
4351
# buffers
4452

4553
self.register_buffer('C_inv', torch.inverse(C))
4654

4755
@beartype
4856
def forward(
4957
self,
58+
prompt_ids: Tensor,
5059
text_enc: Tensor,
5160
concept_indices: Tensor
5261
):
5362
"""
5463
following the pseudocode of Algorithm 1 in appendix
5564
"""
5665

57-
return text_enc
66+
batch, device = text_enc.shape[0], self.initted.device
67+
68+
weights, decay = self.weight, self.input_decay
69+
70+
# beta and temperature depends on whether training or inference
71+
72+
beta, temperature = (self.train_beta, self.train_temperature) if self.training else (self.eval_beta, self.eval_temperature)
73+
74+
# extract the concept text encoding input
75+
76+
batch_indices = torch.arange(batch, device = device)
77+
batch_indices = rearrange(batch_indices, 'b -> b 1')
78+
concept_indices = rearrange(concept_indices, 'b -> b 1')
79+
80+
concept_text_enc = text_enc[batch_indices, concept_indices]
81+
concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')
82+
83+
# during training, keep track of exponentially smoothed input
84+
85+
if self.training:
86+
batch_initted = self.initted[prompt_ids]
87+
ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
88+
89+
ema_concept_text_enc = torch.where(
90+
rearrange(batch_initted, 'b -> b 1'),
91+
ema_concept_text_enc,
92+
concept_text_enc
93+
)
94+
95+
# update using exponential moving average
96+
97+
ema_concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
98+
99+
self.initted[prompt_ids] = True
100+
self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
101+
102+
return einsum('b n i, o i -> b n o', text_enc, weights)

0 commit comments

Comments
 (0)