Commit c32bbb6

add a function for extracting the necessary parameters for fine-tuning, directly off the stable diffusion (or any text-to-<modality>) instance being operated on
1 parent 4028b4e commit c32bbb6

File tree

5 files changed: +26 -7 lines changed

README.md

Lines changed: 3 additions & 2 deletions
@@ -48,7 +48,8 @@ wrapped_to_values = Rank1EditModule(
 
 text_enc = torch.randn(4, 77, 768) # regular input
 text_enc_with_superclass = torch.randn(4, 77, 768) # init_input in algorithm 1, for key-locking
-concept_indices = torch.randint(0, 77, (4,))
+concept_indices = torch.randint(0, 77, (4,)) # index where the concept or superclass concept token is in the sequence
+key_pad_mask = torch.ones(4, 77).bool()
 
 keys = wrapped_to_keys(
     text_enc,
@@ -81,7 +82,7 @@ from torch import nn
 
 from perfusion_pytorch import EmbeddingWrapper
 
-embed = nn.Embedding(49407, 512) # open clip embedding, somewhere in the module tree of stable diffusion
+embed = nn.Embedding(49408, 512) # open clip embedding, somewhere in the module tree of stable diffusion
 
 # wrap it, and will automatically create a new concept for learning, based on the superclass embed string
 
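Two notes on this hunk. The size fix from 49407 to 49408 matches CLIP's full BPE vocabulary: id 49407 is the <|endoftext|> token, so the embedding table needs 49408 rows for that id to be indexable. And since the hunk truncates the wrapped_to_keys call, here is a sketch of how the new key_pad_mask plausibly threads through it; the keyword names other than text_enc are assumptions on my part, not taken from the diff:

import torch

# True where a token is real, False where it is padding (assumed convention)
key_pad_mask = torch.ones(4, 77).bool()

# hypothetical full call, continuing the README snippet above
keys = wrapped_to_keys(
    text_enc,
    concept_indices = concept_indices,
    text_enc_with_superclass = text_enc_with_superclass,
    key_pad_mask = key_pad_mask
)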
perfusion_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -13,5 +13,6 @@
 
 from perfusion_pytorch.save_load import (
     save,
-    load
+    load,
+    get_finetune_parameters
 )
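
With this re-export, the new helper becomes importable from the package root alongside save and load:

from perfusion_pytorch import save, load, get_finetune_parameters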

perfusion_pytorch/embedding.py

Lines changed: 5 additions & 3 deletions
@@ -71,7 +71,7 @@ def __init__(
         self.num_concepts = num_concepts
         self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
 
-        assert exists(superclass_embed_id) ^ exists(superclass_string), 'either superclass embed id is given, or the superclass string'
+        assert not (exists(superclass_embed_id) and exists(superclass_string)), 'either superclass embed id is given, or the superclass string'
 
         self.pad_id = tokenizer_pad_id
         self.tokenize = None
@@ -130,8 +130,8 @@ def forward(
         if not isinstance(concept_id, tuple):
             concept_id = (concept_id,)
 
-        assert len(concept_id) == 1, 'can only train or inference on single concepts if passing in list of superclass strings'
-        assert self.num_concepts == 1
+        assert not self.training or len(concept_id) == 1, 'can only train or inference on single concepts if passing in list of superclass strings'
+        assert not self.training or self.num_concepts == 1
 
         if is_bearable(x, List[str]):
             inferred_concept_id = self.concept_embed_ids[0]
@@ -221,6 +221,8 @@ def merge_embedding_wrappers(
         num_concepts = total_concepts
     )
 
+    merged_concepts.eval()
+
     concepts = torch.cat(tuple(embed.concepts.data for embed in embeds), dim = 0)
 
     merged_concepts.concepts = nn.Parameter(concepts)
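
The two assertion changes loosen EmbeddingWrapper in complementary ways. In __init__, the xor is relaxed to a nand: supplying neither superclass argument now passes the check, and only supplying both is rejected. In forward, the single-concept asserts are gated on self.training, so a merged multi-concept wrapper can run inference without tripping them; that is why merge_embedding_wrappers now calls merged_concepts.eval() on the result. A minimal sketch of what the relaxed asserts permit (whether the rest of __init__ fully supports the no-superclass path is not shown in this hunk, and the embed id below is made up for illustration):

from torch import nn
from perfusion_pytorch import EmbeddingWrapper

embed = nn.Embedding(49408, 512)

wrapped = EmbeddingWrapper(embed, superclass_string = 'dog') # one argument: passes, as before
wrapped = EmbeddingWrapper(embed)                            # neither argument: now passes the nand

# both arguments at once still trips the assert (1750 is a hypothetical embed id)
# wrapped = EmbeddingWrapper(embed, superclass_embed_id = 1750, superclass_string = 'dog')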

perfusion_pytorch/save_load.py

Lines changed: 15 additions & 0 deletions
@@ -9,9 +9,24 @@
 from perfusion_pytorch.embedding import EmbeddingWrapper
 from perfusion_pytorch.perfusion import Rank1EditModule
 
+# helper functions
+
 def exists(val):
     return val is not None
 
+# function that automatically finds all the parameters necessary for fine tuning
+
+@beartype
+def get_finetune_parameters(text_image_model: Module):
+    params = []
+    for module in text_image_model.modules():
+        if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
+            params.extend(module.parameters())
+
+    return params
+
+# saving and loading the necessary extra finetuned params
+
 @beartype
 def save(
     text_image_model: Module,
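
This is the function the commit message describes: it walks the model's module tree and collects parameters only from the Perfusion-specific modules (EmbeddingWrapper and Rank1EditModule), so everything else receives no updates. A minimal usage sketch, where model stands in for an already-wrapped stable diffusion instance, and the optimizer choice and learning rate are my own assumptions:

from torch.optim import Adam
from perfusion_pytorch import get_finetune_parameters

# gather just the parameters introduced by the wrappers
params = get_finetune_parameters(model)

# hand only those to the optimizer; the rest of the network is never updated
opt = Adam(params, lr = 1e-4)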

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name = 'perfusion-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.1.11',
+    version = '0.1.12',
     license='MIT',
     description = 'Perfusion - Pytorch',
     author = 'Phil Wang',
