Skip to content

Commit b42836f

Browse files
committed
Take into account some feedback from an email exchange with the author
1 parent 463c918 commit b42836f

File tree

3 files changed

+15
-16
lines changed

3 files changed

+15
-16
lines changed

README.md

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,10 @@ import torch
2323
from perfusion_pytorch import Rank1EditModule
2424
from torch import nn
2525

26-
to_keys = nn.Linear(512, 1024, bias = False)
27-
to_values = nn.Linear(512, 1024, bias = False)
26+
to_keys = nn.Linear(768, 320, bias = False)
27+
to_values = nn.Linear(768, 320, bias = False)
2828

29-
input_covariance = torch.randn(512, 512)
29+
input_covariance = torch.randn(768, 768)
3030

3131
wrapped_to_keys = Rank1EditModule(
3232
to_keys,
@@ -41,10 +41,9 @@ wrapped_to_values = Rank1EditModule(
4141
num_finetune_prompts = 32
4242
)
4343

44-
text_enc = torch.randn(4, 256, 512) # regular input
45-
text_enc_with_superclass = torch.randn(4, 256, 512) # init_input in algorithm 1, for key-locking
46-
concept_ids = torch.randint(0, 256, (4,))
47-
44+
text_enc = torch.randn(4, 77, 768) # regular input
45+
text_enc_with_superclass = torch.randn(4, 77, 768) # init_input in algorithm 1, for key-locking
46+
concept_ids = torch.randint(0, 77, (4,))
4847

4948
keys = wrapped_to_keys(text_enc, text_enc_with_superclass, concept_ids)
5049
values = wrapped_to_values(text_enc, text_enc_with_superclass, concept_ids)

perfusion_pytorch/perfusion.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def __init__(
2626
*,
2727
num_finetune_prompts: int,
2828
C: Tensor, # covariance of input, precomputed from 100K laion text
29-
text_seq_len: int = 256,
29+
text_seq_len: int = 77,
3030
is_key_proj: bool = False,
3131
input_decay = 0.99,
3232
train_beta = 0.75,
@@ -92,16 +92,16 @@ def forward(
9292
concept_text_enc = text_enc[batch_indices, concept_indices]
9393
concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')
9494

95-
superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
96-
superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
97-
9895
# take care of initializing with superclass prompt
9996
# for key-locking - this assumes stable diffusion was modified so text encoder takes in a prompt with both the <concept> as well as <superclass> - it seems this also has the limitation that <superclass> must be one token
10097

101-
text_enc_with_superclass_output = einsum('b n i, o i -> b n o', text_enc_with_superclass, weights)
98+
superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
99+
superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
100+
101+
superclass_text_enc_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
102102

103103
if self.is_key_proj:
104-
text_enc_with_superclass_output = text_enc_with_superclass_output.detach()
104+
superclass_text_enc_output = superclass_text_enc_output.detach()
105105

106106
# only during training do they exponentially smooth
107107

@@ -112,7 +112,7 @@ def forward(
112112

113113
# make it easier to match with paper
114114

115-
i, o, W = online_estimated_concept_enc, text_enc_with_superclass_output, weights
115+
i, o, W = online_estimated_concept_enc, superclass_text_enc_output, weights
116116

117117
# main contribution eq (3)
118118

@@ -131,4 +131,4 @@ def forward(
131131

132132
W_em_orthogonal_term = text_enc_output - (sim * concept_output / i_energy)
133133

134-
return W_em_orthogonal_term + sigmoid_term * o
134+
return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'b d -> b 1 d')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.6',
6+
version = '0.0.7',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)