Skip to content

Commit e7ac291

Browse files
committed
init_input is actually the prompt with the superclass. This will be hard to do in a modular fashion; we may have to directly modify Stable Diffusion so that the text encoder receives the two prompts and passes them directly to all the cross-attention modules.
1 parent 8c9b935 commit e7ac291

File tree

2 files changed

+12
-43
lines changed

2 files changed

+12
-43
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,17 @@ def __init__(
4545

4646
self.input_decay = input_decay
4747

48-
# they exponentially smooth the text encoding inputs during training
49-
# in addition to a lowered learning rate on the text encodings
50-
5148
self.text_seq_len = text_seq_len
5249

53-
self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
54-
self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
55-
self.register_buffer('outputs', torch.zeros(num_finetune_prompts, text_seq_len, dim_output))
56-
5750
# C in the paper, inverse precomputed
5851

5952
self.register_buffer('C_inv', torch.inverse(C))
6053

6154
@beartype
6255
def forward(
6356
self,
64-
prompt_ids: Tensor,
6557
text_enc: Tensor,
58+
text_enc_with_superclass: Tensor,
6659
concept_indices: Tensor
6760
):
6861
assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
@@ -78,7 +71,7 @@ def forward(
7871
o - output dimension
7972
"""
8073

81-
batch, device = text_enc.shape[0], self.initted.device
74+
batch, device = text_enc.shape[0], self.C_inv.device
8275

8376
weights, decay, Ci = self.weight, self.input_decay, self.C_inv
8477

@@ -95,46 +88,22 @@ def forward(
9588
concept_text_enc = text_enc[batch_indices, concept_indices]
9689
concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')
9790

98-
# during training, keep track of exponentially smoothed input
99-
100-
if self.training:
101-
102-
# get the initialization state
103-
# as well as the exponentially smoothed text encodings
104-
105-
initted = self.initted[prompt_ids]
106-
all_initted = initted.all()
107-
108-
ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
109-
outputs = self.outputs[prompt_ids]
110-
111-
# if any in the batch is not initialized, initialize
112-
113-
if not all_initted:
114-
ema_concept_text_enc = torch.where(
115-
rearrange(initted, 'b -> b 1'),
116-
ema_concept_text_enc,
117-
concept_text_enc
118-
)
91+
superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
92+
superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
11993

120-
outputs = torch.where(
121-
rearrange(initted, 'b -> b 1 1'),
122-
outputs,
123-
einsum('o i, b n i -> b n o', weights, text_enc)
124-
)
94+
# take care of initializing with superclass prompt
95+
# for key-locking - this assumes stable diffusion was modified so text encoder takes in a prompt with both the <concept> as well as <superclass> - it seems this also has the limitation that <superclass> must be one token
12596

126-
# update using exponential moving average
97+
text_enc_with_superclass_output = einsum('b n i, o i -> b n o', text_enc_with_superclass, weights)
12798

128-
concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
99+
if self.is_key_proj:
100+
text_enc_with_superclass_output = text_enc_with_superclass_output.detach()
129101

130-
if not all_initted:
131-
self.initted[prompt_ids] = True
132-
self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
133-
self.outputs[prompt_ids] = outputs
102+
online_estimated_concept_enc = decay * superclass_text_enc + (1. - decay) * concept_text_enc
134103

135104
# make it easier to match with paper
136105

137-
i, o, W = ema_concept_text_enc, outputs, weights
106+
i, o, W = online_estimated_concept_enc, text_enc_with_superclass_output, weights
138107

139108
# main contribution eq (3)
140109

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.2',
6+
version = '0.0.3',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)