
Commit 86d1896

account for more feedback
1 parent 794669d commit 86d1896

File tree

perfusion_pytorch/perfusion.py
setup.py

2 files changed: +28 −35 lines


perfusion_pytorch/perfusion.py

Lines changed: 27 additions & 34 deletions
```diff
@@ -72,8 +72,6 @@ def __init__(
         self.weight = key_or_values_proj.weight
         dim_output, dim_input = self.weight.shape
 
-        self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
-
         self.train_beta = train_beta
         self.train_temperature = train_temperature
         self.eval_beta = eval_beta
@@ -88,13 +86,23 @@ def __init__(
 
         self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
         self.register_buffer('ema_concept_text_encs', torch.zeros(num_finetune_prompts, dim_input))
-        self.register_buffer('superclass_text_encs', torch.zeros(num_finetune_prompts, dim_input))
-        self.register_buffer('superclass_outputs', torch.zeros(num_finetune_prompts, dim_output))
+
+        # superclass outputs - only optimized for values, but not keys
+
+        self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
+
+        self.superclass_outputs = nn.Parameter(torch.zeros(num_finetune_prompts, dim_output), requires_grad = not is_key_proj)
 
         # C in the paper, inverse precomputed
 
        self.register_buffer('C_inv', torch.inverse(C))
 
+    def parameters(self):
+        if not self.is_key_proj:
+            return []
+
+        return [self.superclass_outputs]
+
     @beartype
     def forward(
         self,
```
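The hunk above replaces the stored `superclass_outputs` buffer with an `nn.Parameter` whose `requires_grad` is gated on `is_key_proj`, and adds a `parameters()` override so that what gets handed to an optimizer can be controlled explicitly. Below is a minimal, generic sketch of that gating pattern; `ToyLockedOutput` and its arguments are made-up illustrative names rather than this library's API, and the sketch follows the intent stated in the hunk's comment (the stored outputs are trainable only when they are not locked).

```python
# Illustrative sketch of the requires_grad gating + parameters() override pattern.
# 'ToyLockedOutput' and its arguments are assumed names, not part of perfusion-pytorch.
import torch
from torch import nn

class ToyLockedOutput(nn.Module):
    def __init__(self, num_prompts, dim_output, locked):
        super().__init__()
        self.locked = locked

        # when locked, the stored outputs never receive gradients
        self.outputs = nn.Parameter(torch.zeros(num_prompts, dim_output), requires_grad = not locked)

    def parameters(self):
        # hand the optimizer nothing when the outputs are locked
        if self.locked:
            return []

        return [self.outputs]

trainable = ToyLockedOutput(num_prompts = 4, dim_output = 8, locked = False)
frozen = ToyLockedOutput(num_prompts = 4, dim_output = 8, locked = True)

optim = torch.optim.Adam(trainable.parameters(), lr = 1e-3)  # sees exactly one parameter
assert len(list(frozen.parameters())) == 0                   # nothing to optimize
```

Gating `requires_grad` stops autograd from computing gradients for a locked copy, while the `parameters()` override keeps it out of the optimizer's state entirely.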
```diff
@@ -134,59 +142,46 @@ def forward(
         concept_text_enc = text_enc[batch_indices, concept_indices]
         concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')
 
-        # take care of initializing with superclass prompt
-        # for key-locking - this assumes stable diffusion was modified so text encoder takes in a prompt with both the <concept> as well as <superclass> - it seems this also has the limitation that <superclass> must be one token
-
-        superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
-        superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
-
-        superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
-
         # only if training, and if prompt ids are given
         # do exponential smoothing of the inputs, both concept and superclass
 
+        if exists(text_enc_with_superclass):
+            superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
+            superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
+
+            superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
+
         if self.training and exists(prompt_ids):
             # get the initialization state
             # as well as the exponentially smoothed text encodings
 
             initted = self.initted[prompt_ids]
-            initted = rearrange(initted, 'b -> b 1')
             all_initted = initted.all()
 
             ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]
 
-            # fetch superclass
+            # store the superclass i* if not all initialized
+            # else fetch it from the buffer
 
-            assert exists(superclass_text_enc) or all_initted
+            if not all_initted:
+                assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first epoch for all prompts to initialize the module correctly'
 
-            stored_superclass_text_enc = self.superclass_text_encs[prompt_ids]
+                non_initted_prompt_ids = prompt_ids[~initted]
 
-            # for keys, the superclass output (o*) is stored on init
-            # and never optimized
+                # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
+                self.superclass_outputs[non_initted_prompt_ids].data.copy_(superclass_output)
 
-            stored_superclass_output = self.superclass_outputs[prompt_ids]
+            superclass_output = self.superclass_outputs[prompt_ids]
 
             # if any in the batch is not initialized, initialize
 
             if not all_initted:
                 ema_concept_text_enc = torch.where(
-                    initted,
+                    rearrange(initted, 'b -> b 1'),
                     ema_concept_text_enc,
                     concept_text_enc
                 )
 
-                superclass_text_enc = torch.where(
-                    initted,
-                    stored_superclass_text_enc,
-                    superclass_text_enc
-                )
-
-                superclass_output = torch.where(
-                    initted,
-                    stored_superclass_output,
-                    superclass_output
-                )
-
             # exponential moving average for concept input encoding
 
             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
@@ -196,8 +191,6 @@ def forward(
             if not all_initted:
                 self.initted[prompt_ids] = True
                 self.ema_concept_text_encs[prompt_ids] = ema_concept_text_enc
-                self.superclass_text_encs[prompt_ids] = superclass_text_enc
-                self.superclass_outputs[prompt_ids] = superclass_output
 
         # take care of the output
         # for the keys, make sure to turn off gradients as it is 'locked'
```
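The `forward()` changes above lean on a lazy, per-prompt initialization scheme: a boolean `initted` buffer records which prompt ids have been seen, the first occurrence hard-initializes the stored encoding, and later occurrences are folded in with an exponential moving average. Below is a small self-contained sketch of that bookkeeping pattern; `ToyEMAStore` and its `update` method are assumed names for illustration, not part of perfusion-pytorch.

```python
# Illustrative sketch of the init-once-then-EMA bookkeeping used in the forward() diff above.
# 'ToyEMAStore' and 'update' are assumed names for demonstration, not perfusion-pytorch's API.
import torch
from torch import nn
from einops import rearrange

class ToyEMAStore(nn.Module):
    def __init__(self, num_prompts, dim, decay = 0.99):
        super().__init__()
        self.decay = decay
        self.register_buffer('initted', torch.zeros(num_prompts).bool())
        self.register_buffer('ema_enc', torch.zeros(num_prompts, dim))

    def update(self, prompt_ids, enc):
        initted = self.initted[prompt_ids]
        ema_enc = self.ema_enc[prompt_ids]

        # prompts seen for the first time start their average at the current encoding
        ema_enc = torch.where(rearrange(initted, 'b -> b 1'), ema_enc, enc)

        # exponential moving average (a no-op change for freshly initialized prompts)
        ema_enc = ema_enc * self.decay + enc * (1. - self.decay)

        # write back the smoothed value and mark these prompts as initialized
        self.ema_enc[prompt_ids] = ema_enc
        self.initted[prompt_ids] = True
        return ema_enc

store = ToyEMAStore(num_prompts = 8, dim = 16)
smoothed = store.update(torch.tensor([0, 3]), torch.randn(2, 16))  # first call: pass-through
smoothed = store.update(torch.tensor([0, 3]), torch.randn(2, 16))  # later calls: smoothed
```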

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
```
