
Commit bd50036

clarified with author that superclass output is initialized using the mean of the first batch of superclass encodings
1 parent fa25768 commit bd50036

3 files changed: +36 -42 lines changed
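A minimal sketch of the initialization described in the commit message, using the README's shapes (batch of 4, 768 → 320) and a random stand-in for the frozen projection weight; the actual module stores this output in a registered parameter:

```python
import torch

weights = torch.randn(320, 768)                        # stand-in for the frozen to_keys / to_values weight
superclass_encs = torch.randn(4, 768)                  # superclass token encodings from the first batch
superclass_text_enc = superclass_encs.mean(dim = 0)    # mean over the first batch
superclass_output = superclass_text_enc @ weights.t()  # initial superclass output o*
```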

README.md

Lines changed: 11 additions & 11 deletions
@@ -22,9 +22,10 @@ $ pip install perfusion-pytorch
 
 ```python
 import torch
-from perfusion_pytorch import Rank1EditModule
 from torch import nn
 
+from perfusion_pytorch import Rank1EditModule
+
 to_keys = nn.Linear(768, 320, bias = False)
 to_values = nn.Linear(768, 320, bias = False)
 
@@ -33,37 +34,34 @@ input_covariance = torch.randn(768, 768)
 wrapped_to_keys = Rank1EditModule(
     to_keys,
     C = input_covariance,
-    is_key_proj = True,
-    num_finetune_prompts = 32
+    is_key_proj = True
 )
 
 wrapped_to_values = Rank1EditModule(
     to_values,
-    C = input_covariance,
-    num_finetune_prompts = 32
+    C = input_covariance
 )
 
-prompt_ids = torch.arange(4).long() # id of each training prompt, so that it can automatically keep track of the EMA
 text_enc = torch.randn(4, 77, 768) # regular input
 text_enc_with_superclass = torch.randn(4, 77, 768) # init_input in algorithm 1, for key-locking
 concept_ids = torch.randint(0, 77, (4,))
 
 keys = wrapped_to_keys(
     text_enc,
     text_enc_with_superclass,
-    concept_ids,
-    prompt_ids = prompt_ids
+    concept_ids
 )
 
 values = wrapped_to_values(
     text_enc,
     text_enc_with_superclass,
-    concept_ids,
-    prompt_ids = prompt_ids
+    concept_ids
 )
 
 # after much training ...
-# simply omit the prompt ids
+
+wrapped_to_keys.eval()
+wrapped_to_values.eval()
 
 keys = wrapped_to_keys(
     text_enc,
@@ -80,6 +78,8 @@ values = wrapped_to_values(
 
 ## Todo
 
+- [ ] handle rank-1 update for multiple concepts
+
 - [x] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
 - [x] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
 
perfusion_pytorch/perfusion.py

Lines changed: 24 additions & 30 deletions
@@ -7,7 +7,7 @@
 from torch.nn import Module
 import torch.nn.functional as F
 
-from einops import rearrange
+from einops import rearrange, reduce
 
 from opt_einsum import contract as opt_einsum
 
@@ -45,7 +45,8 @@ def calculate_input_covariance(
 
     all_embeds = torch.cat((all_embeds), dim = 0)
     all_embeds = rearrange(all_embeds, 'n d -> d n')
-    return torch.cov(all_embeds, **cov_kwargs)
+
+    return torch.cov(all_embeds, correction = 0, **cov_kwargs)
 
 # a module that wraps the keys and values projection of the cross attentions to text encodings
 
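For context, `correction = 0` makes `torch.cov` return the population covariance (normalized by N) rather than the default sample covariance (normalized by N − 1); a small illustrative check, with assumed sizes:

```python
import torch

embeds = torch.randn(768, 10000)                     # (dim, num_embeddings), as after the 'n d -> d n' rearrange
sample_cov = torch.cov(embeds)                       # default correction = 1, divides by N - 1
population_cov = torch.cov(embeds, correction = 0)   # divides by N
```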
@@ -56,7 +57,6 @@ def __init__(
         self,
         key_or_values_proj: nn.Linear,
         *,
-        num_finetune_prompts: int,
         C: Tensor, # covariance of input, precomputed from 100K laion text
         text_seq_len: int = 77,
         is_key_proj: bool = False,
@@ -90,14 +90,14 @@ def __init__(
         # for exponentially smoothing the inputs
         # will smooth both concept and superclass token inputs
 
-        self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
-        self.register_buffer('ema_concept_text_encs', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('initted', Tensor([False]))
+        self.register_buffer('ema_concept_text_encs', torch.zeros(dim_input))
 
         # superclass outputs - only optimized for values, but not keys
 
         self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
 
-        self.superclass_outputs = nn.Parameter(torch.zeros(num_finetune_prompts, dim_output), requires_grad = not is_key_proj)
+        self.superclass_outputs = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)
 
         # C in the paper, inverse precomputed
 
@@ -150,57 +150,52 @@ def forward(
         concept_indices = rearrange(concept_indices, 'b -> b 1')
 
         concept_text_enc = text_enc[batch_indices, concept_indices]
-        concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')
+        concept_text_enc = reduce(concept_text_enc, 'b 1 d -> d', 'mean')
 
-        # only if training, and if prompt ids are given
+        # only if training
         # do exponential smoothing of the inputs, both concept and superclass
 
         if exists(text_enc_with_superclass):
             superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
-            superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
+            superclass_text_enc = reduce(superclass_text_enc, 'b 1 d -> d', 'mean')
 
-            superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
+            superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
 
         if self.training and exists(prompt_ids):
             # get the initialization state
             # as well as the exponentially smoothed text encodings
 
-            initted = self.initted[prompt_ids]
-            all_initted = initted.all()
+            initted = self.initted.item()
 
             ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]
 
             # store the superclass i* if not all initialized
             # else fetch it from the buffer
 
-            if not all_initted:
+            if not initted:
                 assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first epoch for all prompts to initialize the module correctly'
 
                 non_initted_prompt_ids = prompt_ids[~initted]
 
                 # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
-                self.superclass_outputs[non_initted_prompt_ids].data.copy_(superclass_output)
+                self.superclass_outputs.data.copy_(superclass_output)
 
-            superclass_output = self.superclass_outputs[prompt_ids]
+            superclass_output = self.superclass_outputs
 
             # if any in the batch is not initialized, initialize
 
-            if not all_initted:
-                ema_concept_text_enc = torch.where(
-                    rearrange(initted, 'b -> b 1'),
-                    ema_concept_text_enc,
-                    concept_text_enc
-                )
+            if not initted:
+                ema_concept_text_enc = concept_text_enc
 
             # exponential moving average for concept input encoding
 
             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
 
             # store
 
-            if not all_initted:
-                self.initted[prompt_ids] = True
-                self.ema_concept_text_encs[prompt_ids] = ema_concept_text_enc
+            if not initted:
+                self.initted.data.copy_(Tensor([True]))
+                self.ema_concept_text_encs.data.copy_(ema_concept_text_enc)
 
         # take care of the output
         # for the keys, make sure to turn off gradients as it is 'locked'
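Pulled out of the module, the smoothing in this hunk reduces the selected concept token encodings to their batch mean, hard-initializes the EMA buffer from the first batch, and exponentially smooths afterwards. A self-contained sketch, with an assumed decay of 0.99 (the real module keeps `initted` and the EMA in registered buffers):

```python
import torch

dim = 768
decay = 0.99                                    # assumed EMA decay
initted = False                                 # plays the role of the 'initted' buffer
ema_concept_text_enc = torch.zeros(dim)         # plays the role of 'ema_concept_text_encs'

for _ in range(3):                              # stand-in for successive training batches
    batch_concept_encs = torch.randn(4, dim)    # concept token encodings selected from the batch
    concept_text_enc = batch_concept_encs.mean(dim = 0)   # the 'b 1 d -> d' mean-reduce

    if not initted:
        ema_concept_text_enc = concept_text_enc           # first batch: copy the batch mean
        initted = True
    else:
        ema_concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
```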
@@ -214,19 +209,18 @@ def forward(
 
         # main contribution eq (3)
 
-        i_energy = opt_einsum('b o, o i, b i -> b', i, Ci, i)
-        i_energy = rearrange(i_energy, '... -> ... 1 1')
+        i_energy = opt_einsum('o, o i, i ->', i, Ci, i)
 
-        sim = opt_einsum('b n o, o i, b i -> b n', text_enc, Ci, i)
+        sim = opt_einsum('b n o, o i, i -> b n', text_enc, Ci, i)
         sim = rearrange(sim, '... -> ... 1')
 
         sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
 
         text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
 
-        concept_output = einsum('b i, o i -> b o', i, W)
-        concept_output = rearrange(concept_output, 'b d -> b 1 d')
+        concept_output = einsum('i, o i -> o', i, W)
+        concept_output = rearrange(concept_output, 'd -> 1 1 d')
 
         W_em_orthogonal_term = text_enc_output - (sim * concept_output / i_energy)
 
-        return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'b d -> b 1 d')
+        return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'd -> 1 1 d')
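For readers following eq. (3), here is a self-contained numerical sketch of the edited projection computed above, with made-up shapes, a random stand-in for the inverse input covariance, and assumed gate hyperparameters:

```python
import torch
from torch import einsum

b, n, dim_in, dim_out = 2, 77, 768, 320      # assumed shapes, matching the README example
W = torch.randn(dim_out, dim_in)             # frozen projection weight
Ci = torch.randn(dim_in, dim_in)             # stand-in for the precomputed inverse covariance C^-1
i_star = torch.randn(dim_in)                 # smoothed concept input encoding i*
o_star = torch.randn(dim_out)                # target output o* (locked superclass keys or learned values)
text_enc = torch.randn(b, n, dim_in)
beta, temperature = 0.75, 0.1                # assumed gate hyperparameters

i_energy = einsum('o, o i, i ->', i_star, Ci, i_star)                  # scalar i*^T C^-1 i*
sim = einsum('b n o, o i, i -> b n', text_enc, Ci, i_star)[..., None]  # per-token similarity to i*
gate = (((sim / i_energy) - beta) / temperature).sigmoid()

text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)           # ordinary projection of the tokens
concept_output = einsum('i, o i -> o', i_star, W)                      # projection of i*

W_em_orthogonal = text_enc_output - (sim * concept_output / i_energy)  # component orthogonal to i*
output = W_em_orthogonal + gate * o_star                               # gated rank-1 edit, shape (b, n, dim_out)
```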

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
