
Commit f095f57
more back and forth with author through emails
1 parent bfbad05
2 files changed: 24 additions, 11 deletions

perfusion_pytorch/perfusion.py (23 additions, 10 deletions)
@@ -57,8 +57,9 @@ def __init__(
         # will smooth both concept and superclass token inputs
 
         self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
-        self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
-        self.register_buffer('ema_superclass_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('ema_concept_text_encs', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('ema_superclass_text_encs', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('superclass_outputs', torch.zeros(num_finetune_prompts, dim_output))
 
         # C in the paper, inverse precomputed
 
@@ -109,6 +110,8 @@ def forward(
         superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
         superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
 
+        superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
+
         # only if training, and if prompt ids are given
         # do exponential smoothing of the inputs, both concept and superclass
 
@@ -120,8 +123,13 @@ def forward(
             initted = rearrange(initted, 'b -> b 1')
             all_initted = initted.all()
 
-            ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
-            ema_superclass_text_enc = self.ema_superclass_text_enc[prompt_ids]
+            ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]
+            ema_superclass_text_enc = self.ema_superclass_text_encs[prompt_ids]
+
+            # for keys, the superclass output (o*) is stored on init
+            # and never optimized
+
+            stored_superclass_output = self.superclass_outputs[prompt_ids]
 
             # if any in the batch is not initialized, initialize
 
@@ -138,6 +146,12 @@ def forward(
                 superclass_text_enc
            )
 
+            superclass_output = torch.where(
+                initted,
+                stored_superclass_output,
+                superclass_output
+            )
+
             # exponential moving average of both concept and superclass
 
             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
@@ -147,20 +161,19 @@ def forward(
 
             if not all_initted:
                 self.initted[prompt_ids] = True
-                self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
-                self.ema_superclass_text_enc[prompt_ids] = ema_superclass_text_enc
+                self.ema_concept_text_encs[prompt_ids] = ema_concept_text_enc
+                self.ema_superclass_text_encs[prompt_ids] = ema_superclass_text_enc
+                self.superclass_outputs[prompt_ids] = superclass_output
 
         # take care of the output
         # for the keys, make sure to turn off gradients as it is 'locked'
 
-        superclass_text_enc_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
-
         if self.is_key_proj:
-            superclass_text_enc_output = superclass_text_enc_output.detach()
+            superclass_output = superclass_output.detach()
 
         # make it easier to match with paper
 
-        i, o, W = concept_text_enc, superclass_text_enc_output, weights
+        i, o, W = concept_text_enc, superclass_output, weights
 
         # main contribution eq (3)
 
setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
