Skip to content

Commit 7527f48

Browse files
committed
another round of feedback, update todos
1 parent 548d4f9 commit 7527f48

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ values = wrapped_to_values(
8080

8181
## Todo
8282

83+
- [ ] add the zero-shot masking of concept proposed in paper
84+
- [ ] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
8385
- [ ] handle rank-1 update for multiple concepts
8486
- [x] handle training with multiple concepts
8587
- [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs

perfusion_pytorch/perfusion.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ def __init__(
129129
self.register_buffer('initted', torch.zeros(num_concepts, 1).bool())
130130
self.register_buffer('ema_concept_text_encs', torch.zeros(num_concepts, dim_input))
131131

132-
# superclass outputs - only optimized for values, but not keys
132+
# concept outputs - only optimized for values, but not keys
133133

134134
self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
135135

136-
self.superclass_output = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
136+
self.concept_output = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
137137

138138
# C in the paper, inverse precomputed
139139

@@ -143,7 +143,7 @@ def parameters(self):
143143
if not self.is_key_proj:
144144
return []
145145

146-
return [self.superclass_outputs]
146+
return [self.concept_output]
147147

148148
@beartype
149149
def forward(
@@ -209,15 +209,15 @@ def forward(
209209
if not initted:
210210
assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'
211211

212-
# for the prompt ids not initialized yet, hard copy over the initial superclass outputs
213-
self.superclass_output[concept_id].data.copy_(superclass_output)
212+
# init concept output with superclass output - fixed for keys, learned for values
213+
self.concept_output[concept_id].data.copy_(superclass_output)
214214

215-
elif exists(superclass_output):
215+
elif exists(superclass_output) and self.is_key_proj:
216216
# if text enc with superclass is passed in for more than 1 batch
217-
# just take the opportunity to exponentially average it a bit more
217+
# just take the opportunity to exponentially average it a bit more for the keys, which have fixed concept output (to superclass)
218218

219-
ema_superclass_output = self.superclass_output * decay + superclass_output * (1. - decay)
220-
self.superclass_output[concept_id].data.copy_(ema_superclass_output)
219+
ema_concept_output = self.concept_output * decay + superclass_output * (1. - decay)
220+
self.concept_output[concept_id].data.copy_(ema_concept_output)
221221

222222
# if any in the batch is not initialized, initialize
223223

@@ -234,13 +234,13 @@ def forward(
234234

235235
if not initted:
236236
self.initted[concept_id].data.copy_(Tensor([True]))
237-
self.ema_concept_text_encs[concept_id].data.copy_(ema_concept_text_enc)
237+
self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
238238
else:
239239
assert initted, 'you have not initialized or trained this module yet'
240240

241241
# make it easier to match with paper
242242

243-
i, o, W = concept_text_enc, self.superclass_output[concept_id], weights
243+
i, o, W = self.ema_concept_text_encs[concept_id], self.concept_output[concept_id], weights
244244

245245
# main contribution eq (3)
246246

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.20',
6+
version = '0.0.21',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)