Commit bfbad05

bring back the logic for automatically initializing the EMA of the concept and superclass inputs
1 parent: b42836f

3 files changed (+88, -13 lines)

README.md (29 additions, 2 deletions)

````diff
@@ -41,12 +41,39 @@ wrapped_to_values = Rank1EditModule(
     num_finetune_prompts = 32
 )
 
+prompt_ids = torch.arange(4).long() # id of each training prompt, so that it can automatically keep track of the EMA
 text_enc = torch.randn(4, 77, 768) # regular input
 text_enc_with_superclass = torch.randn(4, 77, 768) # init_input in algorithm 1, for key-locking
 concept_ids = torch.randint(0, 77, (4,))
 
-keys = wrapped_to_keys(text_enc, text_enc_with_superclass, concept_ids)
-values = wrapped_to_values(text_enc, text_enc_with_superclass, concept_ids)
+keys = wrapped_to_keys(
+    text_enc,
+    text_enc_with_superclass,
+    concept_ids,
+    prompt_ids = prompt_ids
+)
+
+values = wrapped_to_values(
+    text_enc,
+    text_enc_with_superclass,
+    concept_ids,
+    prompt_ids = prompt_ids
+)
+
+# after much training ...
+# simply omit the prompt ids
+
+keys = wrapped_to_keys(
+    text_enc,
+    text_enc_with_superclass,
+    concept_ids
+)
+
+values = wrapped_to_values(
+    text_enc,
+    text_enc_with_superclass,
+    concept_ids
+)
 ```
 
 ## Todo
````
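For orientation, here is a hypothetical sketch of how the new `prompt_ids` argument might be threaded through a fine-tuning loop. Only the `wrapped_to_keys` / `wrapped_to_values` call signatures come from the README diff above; the loop structure, batch size, and random stand-in encodings are illustrative assumptions, and the two wrapped modules are presumed to be constructed as in the README.

```python
# hypothetical fine-tuning loop - only the wrapped_to_keys / wrapped_to_values
# signatures come from the README diff above; everything else is assumed.
# wrapped_to_keys / wrapped_to_values are presumed already constructed via
# Rank1EditModule(..., num_finetune_prompts = 32) as shown in the README

import torch

num_finetune_prompts = 32
batch_size = 4

for step in range(1000):
    # a stable integer id per training prompt, in [0, num_finetune_prompts)
    prompt_ids = torch.randint(0, num_finetune_prompts, (batch_size,))

    # stand-ins for real CLIP text encodings
    text_enc = torch.randn(batch_size, 77, 768)
    text_enc_with_superclass = torch.randn(batch_size, 77, 768)
    concept_ids = torch.randint(0, 77, (batch_size,))

    # passing prompt_ids lets each module update its per-prompt EMA
    # of the concept and superclass token encodings
    keys = wrapped_to_keys(text_enc, text_enc_with_superclass, concept_ids, prompt_ids = prompt_ids)
    values = wrapped_to_values(text_enc, text_enc_with_superclass, concept_ids, prompt_ids = prompt_ids)

# at inference time, call without prompt_ids - the smoothing branch only
# runs when the module is in training mode and prompt ids are provided
```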

perfusion_pytorch/perfusion.py (58 additions, 10 deletions)

```diff
@@ -2,7 +2,7 @@
 from beartype.typing import Union
 
 import torch
-from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
+from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor, Optional
 from torch.nn import Module
 import torch.nn.functional as F
 
@@ -15,6 +15,8 @@
 def exists(val):
     return val is not None
 
+IndicesTensor = Union[LongTensor, IntTensor]
+
 # a module that wraps the keys and values projection of the cross attentions to text encodings
 
 class Rank1EditModule(Module):
@@ -51,6 +53,13 @@ def __init__(
 
         self.text_seq_len = text_seq_len
 
+        # for exponentially smoothing the inputs
+        # will smooth both concept and superclass token inputs
+
+        self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
+        self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('ema_superclass_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+
         # C in the paper, inverse precomputed
 
         self.register_buffer('C_inv', torch.inverse(C))
@@ -60,7 +69,9 @@ def forward(
         self,
         text_enc: FloatTensor,
         text_enc_with_superclass: FloatTensor,
-        concept_indices: Union[IntTensor, LongTensor]
+        concept_indices: IndicesTensor,
+        *,
+        prompt_ids: Optional[IndicesTensor] = None
     ):
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
 
@@ -98,21 +109,58 @@ def forward(
         superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
         superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
 
+        # only if training, and if prompt ids are given,
+        # do exponential smoothing of the inputs, both concept and superclass
+
+        if self.training and exists(prompt_ids):
+            # get the initialization state
+            # as well as the exponentially smoothed text encodings
+
+            initted = self.initted[prompt_ids]
+            initted = rearrange(initted, 'b -> b 1')
+            all_initted = initted.all()
+
+            ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
+            ema_superclass_text_enc = self.ema_superclass_text_enc[prompt_ids]
+
+            # if any in the batch is not initialized, initialize
+
+            if not all_initted:
+                ema_concept_text_enc = torch.where(
+                    initted,
+                    ema_concept_text_enc,
+                    concept_text_enc
+                )
+
+                ema_superclass_text_enc = torch.where(
+                    initted,
+                    ema_superclass_text_enc,
+                    superclass_text_enc
+                )
+
+            # exponential moving average of both concept and superclass
+
+            concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
+            superclass_text_enc = ema_superclass_text_enc * decay + superclass_text_enc * (1. - decay)
+
+            # store
+
+            if not all_initted:
+                self.initted[prompt_ids] = True
+                self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
+                self.ema_superclass_text_enc[prompt_ids] = ema_superclass_text_enc
+
+        # take care of the output
+        # for the keys, make sure to turn off gradients as it is 'locked'
+
         superclass_text_enc_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
 
         if self.is_key_proj:
             superclass_text_enc_output = superclass_text_enc_output.detach()
 
-        # only during training do they exponentially smooth
-
-        if self.training:
-            online_estimated_concept_enc = decay * superclass_text_enc + (1. - decay) * concept_text_enc
-        else:
-            online_estimated_concept_enc = concept_text_enc
-
         # make it easier to match with paper
 
-        i, o, W = online_estimated_concept_enc, superclass_text_enc_output, weights
+        i, o, W = concept_text_enc, superclass_text_enc_output, weights
 
         # main contribution eq (3)
 
```
setup.py (1 addition, 1 deletion)

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
```
