Commit ad968a9

handle the training of multiple concepts; save the inference part for later this week
1 parent: c9adf6d

3 files changed: 21 additions, 13 deletions

README.md

Lines changed: 3 additions & 1 deletion
@@ -78,7 +78,9 @@ values = wrapped_to_values(
 
 ## Todo
 
-- [ ] handle rank-1 update for multiple concepts
+- [ ] handle rank-1 update for multiple concepts
+- [x] handle training with multiple concepts
+- [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
 
 - [x] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
 - [x] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)

perfusion_pytorch/perfusion.py

Lines changed: 17 additions & 11 deletions
@@ -90,6 +90,7 @@ def __init__(
         self,
         key_or_values_proj: nn.Linear,
         *,
+        num_concepts: int = 1,
         C: Tensor, # covariance of input, precomputed from 100K laion text
         text_seq_len: int = 77,
         is_key_proj: bool = False,
@@ -103,6 +104,8 @@ def __init__(
         super().__init__()
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
 
+        self.num_concepts = num_concepts
+
         self.weight = key_or_values_proj.weight
         dim_output, dim_input = self.weight.shape
 
@@ -123,14 +126,14 @@ def __init__(
         # for exponentially smoothing the inputs
         # will smooth both concept and superclass token inputs
 
-        self.register_buffer('initted', Tensor([False]))
-        self.register_buffer('ema_concept_text_encs', torch.zeros(dim_input))
+        self.register_buffer('initted', torch.zeros(num_concepts, 1).bool())
+        self.register_buffer('ema_concept_text_encs', torch.zeros(num_concepts, dim_input))
 
         # superclass outputs - only optimized for values, but not keys
 
         self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
 
-        self.superclass_output = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)
+        self.superclass_output = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
 
         # C in the paper, inverse precomputed
 
@@ -147,7 +150,8 @@ def forward(
         self,
         text_enc: FloatTensor,
         concept_indices: IndicesTensor,
-        text_enc_with_superclass: Optional[FloatTensor] = None
+        text_enc_with_superclass: Optional[FloatTensor] = None,
+        concept_id: int = 0
     ):
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
 
@@ -194,7 +198,9 @@ def forward(
 
         # get the initialization state
 
-        initted = self.initted.item()
+        assert concept_id < self.num_concepts
+
+        initted = self.initted[concept_id].item()
 
         if self.training:
             # store the superclass i* if not all initialized
@@ -204,21 +210,21 @@ def forward(
                 assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'
 
                 # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
-                self.superclass_output.data.copy_(superclass_output)
+                self.superclass_output[concept_id].data.copy_(superclass_output)
 
             elif exists(superclass_output):
                 # if text enc with superclass is passed in for more than 1 batch
                 # just take the opportunity to exponentially average it a bit more
 
                 ema_superclass_output = self.superclass_output * decay + superclass_output * (1. - decay)
-                self.superclass_output.data.copy_(ema_superclass_output)
+                self.superclass_output[concept_id].data.copy_(ema_superclass_output)
 
             # if any in the batch is not initialized, initialize
 
             if not initted:
                 ema_concept_text_enc = concept_text_enc
             else:
-                ema_concept_text_enc = self.ema_concept_text_enc
+                ema_concept_text_enc = self.ema_concept_text_enc[concept_id]
 
             # exponential moving average for concept input encoding
 
@@ -227,14 +233,14 @@ def forward(
             # store
 
             if not initted:
-                self.initted.data.copy_(Tensor([True]))
-                self.ema_concept_text_encs.data.copy_(ema_concept_text_enc)
+                self.initted[concept_id].data.copy_(Tensor([True]))
+                self.ema_concept_text_encs[concept_id].data.copy_(ema_concept_text_enc)
         else:
             assert initted, 'you have not initialized or trained this module yet'
 
         # make it easier to match with paper
 
-        i, o, W = concept_text_enc, self.superclass_output, weights
+        i, o, W = concept_text_enc, self.superclass_output[concept_id], weights
 
         # main contribution eq (3)
 
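Taken together, the perfusion.py changes above give every concept its own row in the initted, ema_concept_text_encs, and superclass_output buffers, selected at training time by the new concept_id argument. Below is a minimal sketch of how that could be driven from a training loop; the Rank1EditModule class name and import path, the 768-dim encoder size, the identity covariance, and the random stand-in tensors are assumptions for illustration, not the library's documented usage.

import torch
from torch import nn
from perfusion_pytorch import Rank1EditModule  # assumed class name and import path

dim = 768            # assumed CLIP text encoder dimension
num_concepts = 2

# values projection of one cross-attention layer (must have no bias, per the module's assert)
to_values = nn.Linear(dim, dim, bias = False)

# covariance of the input encodings; identity used here as a placeholder,
# in practice precomputed from ~100K laion captions
C = torch.eye(dim)

edit_module = Rank1EditModule(
    to_values,
    num_concepts = num_concepts,   # new in this commit
    C = C,
    is_key_proj = False
)

edit_module.train()

# one step per concept - each concept_id indexes its own row of the
# per-concept buffers (initted, ema_concept_text_encs, superclass_output)
for concept_id in range(num_concepts):
    text_enc = torch.randn(1, 77, dim)                   # prompt encoding containing the concept token (stand-in)
    text_enc_with_superclass = torch.randn(1, 77, dim)   # same prompt with the superclass word instead (stand-in)
    concept_indices = torch.tensor([2])                  # assumed position of the concept token in each prompt

    values = edit_module(
        text_enc,
        concept_indices,
        text_enc_with_superclass = text_enc_with_superclass,
        concept_id = concept_id    # new in this commit
    )

As the README todo notes, inference with several concepts in one prompt (summing the sigmoid gates and outputs across concepts) is not covered by this commit and is deferred to a later one.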
setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.19',
+  version = '0.0.20',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
