Commit dc8f568

get some more code down

1 parent bd6e3f9 commit dc8f568
3 files changed: +76 -20 lines

README.md

Lines changed: 10 additions & 5 deletions
@@ -22,29 +22,34 @@ from torch import nn
 to_keys = nn.Linear(512, 1024, bias = False)
 to_values = nn.Linear(512, 1024, bias = False)
 
-C = torch.randn(512, 512)
+input_covariance = torch.randn(512, 512)
 
 wrapped_to_keys = Rank1EditModule(
     to_keys,
-    C = C,
+    C = input_covariance,
     is_key_proj = True,
     num_finetune_prompts = 32
 )
 
 wrapped_to_values = Rank1EditModule(
     to_values,
-    C = C,
+    C = input_covariance,
     num_finetune_prompts = 32
 )
 
 prompt_ids = torch.arange(4).long()
-text_enc = torch.randn(4, 1024, 512)
-concept_ids = torch.randint(0, 1024, (4,))
+text_enc = torch.randn(4, 256, 512)
+concept_ids = torch.randint(0, 256, (4,))
 
 keys = wrapped_to_keys(prompt_ids, text_enc, concept_ids)
 values = wrapped_to_values(prompt_ids, text_enc, concept_ids)
 ```
 
+## Todo
+
+- [ ] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
+- [ ] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
+
 ## Citations
 
 ```bibtex
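
The first item of the new Todo list concerns precomputing the input covariance C that the README example currently fakes with `torch.randn(512, 512)`. A minimal sketch of what such a helper could look like, assuming a `text_encoder` callable that returns encodings of shape `(batch, seq, dim)` and a list of caption strings; the function name and signature are hypothetical, not part of this commit:

```python
import torch

@torch.no_grad()
def calculate_input_covariance(text_encoder, prompts, batch_size = 32):
    # accumulate the uncentered covariance C = E[x xᵀ] of the text encodings
    # over a corpus of captions (the code comments mention 100K laion text)
    accum = None
    num_tokens = 0

    for start in range(0, len(prompts), batch_size):
        enc = text_encoder(prompts[start:(start + batch_size)])  # assumed shape (b, n, d)
        enc = enc.reshape(-1, enc.shape[-1])                      # flatten batch and sequence into tokens

        outer = enc.t() @ enc                                     # (d, d) sum of outer products
        accum = outer if accum is None else accum + outer
        num_tokens += enc.shape[0]

    return accum / num_tokens
```

The resulting `(512, 512)` matrix would then be passed as `C = ...` in place of the random `input_covariance` placeholder above.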
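The second Todo item presumably refers to the trick of scaling the gradient that flows into the concept embedding without changing its forward value, so the researcher does not need a separate learning rate for it. A one-line sketch of that idea (hypothetical helper, not part of this commit):

```python
def frac_gradient(t, frac = 0.1):
    # forward value is exactly t; the gradient reaching t is scaled by `frac`
    return t * frac + t.detach() * (1. - frac)
```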

perfusion_pytorch/perfusion.py

Lines changed: 65 additions & 14 deletions
@@ -11,7 +11,6 @@
 def exists(val):
     return val is not None
 
-# main contribution of paper
 # a module that wraps the keys and values projection of the cross attentions to text encodings
 
 class Rank1EditModule(Module):
@@ -22,7 +21,8 @@ def __init__(
         key_or_values_proj: nn.Linear,
         *,
         num_finetune_prompts: int,
-        C: Tensor,
+        C: Tensor,                  # covariance of input, precomputed from 100K laion text
+        text_seq_len: int = 256,
         is_key_proj: bool = False,
         input_decay = 0.99,
         train_beta = 0.75,
@@ -34,7 +34,7 @@ def __init__(
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
 
         self.weight = key_or_values_proj.weight
-        dim_input = self.weight.shape[-1]
+        dim_output, dim_input = self.weight.shape
 
         self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
 
@@ -48,10 +48,13 @@ def __init__(
         # they exponentially smooth the text encoding inputs during training
         # in addition to a lowered learning rate on the text encodings
 
+        self.text_seq_len = text_seq_len
+
         self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
         self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('outputs', torch.zeros(num_finetune_prompts, text_seq_len, dim_output))
 
-        # buffers
+        # C in the paper, inverse precomputed
 
         self.register_buffer('C_inv', torch.inverse(C))
 
@@ -62,13 +65,22 @@ def forward(
         text_enc: Tensor,
         concept_indices: Tensor
     ):
+        assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
+
         """
         following the pseudocode of Algorithm 1 in appendix
+
+        einstein notation:
+        b - batch
+        n - sequence
+        d - feature dimension
+        i - input dimension
+        o - output dimension
         """
 
         batch, device = text_enc.shape[0], self.initted.device
 
-        weights, decay = self.weight, self.input_decay
+        weights, decay, Ci = self.weight, self.input_decay, self.C_inv
 
         # beta and temperature depends on whether training or inference
 
@@ -86,20 +98,59 @@ def forward(
         # during training, keep track of exponentially smoothed input
 
         if self.training:
-            batch_initted = self.initted[prompt_ids]
+
+            # get the initialization state
+            # as well as the exponentially smoothed text encodings
+
+            initted = self.initted[prompt_ids]
+            all_initted = initted.all()
+
             ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
+            outputs = self.outputs[prompt_ids]
+
+            # if any in the batch is not initialized, initialize
 
-            ema_concept_text_enc = torch.where(
-                rearrange(batch_initted, 'b -> b 1'),
-                ema_concept_text_enc,
-                concept_text_enc
-            )
+            if not all_initted:
+                ema_concept_text_enc = torch.where(
+                    rearrange(initted, 'b -> b 1'),
+                    ema_concept_text_enc,
+                    concept_text_enc
+                )
+
+                outputs = torch.where(
+                    rearrange(initted, 'b -> b 1 1'),
+                    outputs,
+                    einsum('o i, b n i -> b n o', weights, text_enc)
+                )
 
             # update using exponential moving average
 
             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
 
-            self.initted[prompt_ids] = True
-            self.ema_concept_text_enc[prompt_ids] = concept_text_enc
+
+            if not all_initted:
+                self.initted[prompt_ids] = True
+                self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
+                self.outputs[prompt_ids] = outputs
+
+        # make it easier to match with paper
+
+        i, o, W = ema_concept_text_enc, outputs, weights
+
+        # main contribution eq (3)
+
+        i_energy = einsum('b d, b d -> b', i @ Ci, i)
+        i_energy = rearrange(i_energy, '... -> ... 1 1')
+
+        sim = einsum('b n d, b d -> b n', text_enc, i @ Ci)
+        sim = rearrange(sim, '... -> ... 1')
+
+        sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
+
+        orig_output = einsum('b n i, o i -> b n o', text_enc, W)
+
+        concept_output = einsum('b i, o i -> b o', i, W)
+        concept_output = rearrange(concept_output, 'b d -> b 1 d')
+
+        W_em_orthogonal_term = orig_output - (sim * concept_output / i_energy)
 
-        return einsum('b n i, o i -> b n o', text_enc, weights)
+        return W_em_orthogonal_term + sigmoid_term * o
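
For orientation, the new forward pass above implements the gated rank-1 edit referenced by the `# main contribution eq (3)` comment. A rough transcription in the code's own variable names, where x_n is one token of `text_enc`, i is the EMA-smoothed concept encoding, o_n is the stored output for that token captured at initialization, W is the frozen projection weight, C the precomputed input covariance, and beta / temperature are chosen per the training-vs-inference branch mentioned in the unchanged code (this is a reading of the diff, not a verbatim quote of the paper):

```latex
\mathrm{sim}(x_n) = \frac{x_n^{\top} C^{-1} i}{\,i^{\top} C^{-1} i\,}
\qquad
\mathrm{gate}(x_n) = \sigma\!\left(\frac{\mathrm{sim}(x_n) - \beta}{\tau}\right)

W' x_n \;=\; \underbrace{W x_n - \mathrm{sim}(x_n)\, W i}_{\texttt{W\_em\_orthogonal\_term}}
\;+\; \mathrm{gate}(x_n)\, o_n
```

In the code, `i_energy` is the denominator i^T C^-1 i, `sim` is the unnormalized numerator, and the division by `i_energy` happens in both the sigmoid gate and the orthogonal term.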

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',

0 commit comments
