Commit 58e3562

Fix merging embeddings (#226)
Fixed merging of embeddings to follow the changes made in the textual inversion repository. Tested and working. The selection logic is inverted so that the Stable Diffusion (CLIP) implementation is now the default, while the BERT-based alternative remains available behind a flag.
1 parent: b622819 · commit: 58e3562
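
Concretely, the inverted default amounts to the following selection. This is a condensed sketch of the logic this commit adds to scripts/merge_embeddings.py (shown in full in the diff below); a CUDA device is assumed, as in the script:

    import argparse
    from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder

    parser = argparse.ArgumentParser()
    # CLIP/Stable Diffusion is now the default; -sd/--use_bert opts out of it.
    parser.add_argument("-sd", "--use_bert", action="store_true",
                        help="Flag to denote that we are not merging stable diffusion embeddings")
    args = parser.parse_args()

    # BERT remains available as the alternative path.
    embedder = (BERTEmbedder(n_embed=1280, n_layer=32) if args.use_bert
                else FrozenCLIPEmbedder()).cuda()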

2 files changed: 48 additions and 15 deletions.
configs/stable-diffusion/v1-finetune.yaml (4 additions, 3 deletions)
@@ -52,7 +52,7 @@ model:
         ddconfig:
           double_z: true
           z_channels: 4
-          resolution: 512
+          resolution: 256
           in_channels: 3
           out_ch: 3
           ch: 128
@@ -74,7 +74,7 @@ data:
   target: main.DataModuleFromConfig
   params:
     batch_size: 1
-    num_workers: 16
+    num_workers: 2
     wrap: false
     train:
       target: ldm.data.personalized.PersonalizedBase
@@ -105,4 +105,5 @@ lightning:

   trainer:
     benchmark: True
-    max_steps: 6100
+    max_steps: 4000
+
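A quick way to sanity-check the new values is to load the config and assert on them. This is a hypothetical snippet, not part of the commit: the key paths assume the usual latent-diffusion layout (ddconfig under model.params.first_stage_config.params), and OmegaConf is the library the repository already uses to load these YAML configs.

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/stable-diffusion/v1-finetune.yaml")

    # Values changed by this commit; key paths are assumptions (see note above).
    assert cfg.model.params.first_stage_config.params.ddconfig.resolution == 256
    assert cfg.data.params.num_workers == 2
    assert cfg.lightning.trainer.max_steps == 4000
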
scripts/merge_embeddings.py (44 additions, 12 deletions)
@@ -1,12 +1,12 @@
-from ldm.modules.encoders.modules import BERTTokenizer
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
 from ldm.modules.embedding_manager import EmbeddingManager

 import argparse, os
 from functools import partial

 import torch

-def get_placeholder_loop(placeholder_string, tokenizer):
+def get_placeholder_loop(placeholder_string, embedder, use_bert):

     new_placeholder = None

@@ -16,10 +16,36 @@ def get_placeholder_loop(placeholder_string, tokenizer):
         else:
             new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")

-        token = tokenizer(new_placeholder)
+        token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
+
+        if token is not None:
+            return new_placeholder, token
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(
+        string,
+        truncation=True,
+        max_length=77,
+        return_length=True,
+        return_overflowing_tokens=False,
+        padding="max_length",
+        return_tensors="pt"
+    )
+
+    tokens = batch_encoding["input_ids"]
+
+    if torch.count_nonzero(tokens - 49407) == 2:
+        return tokens[0, 1]
+
+    return None
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
+    if torch.count_nonzero(token) == 3:
+        return token[0, 1]
+
+    return None

-        if torch.count_nonzero(token) == 3:
-            return new_placeholder, token[0, 1]

 if __name__ == "__main__":

@@ -40,10 +66,20 @@ def get_placeholder_loop(placeholder_string, tokenizer):
         help="Output path for the merged manager",
     )

+    parser.add_argument(
+        "-sd", "--use_bert",
+        action="store_true",
+        help="Flag to denote that we are not merging stable diffusion embeddings"
+    )
+
     args = parser.parse_args()

-    tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
-    EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
+    if args.use_bert:
+        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
+    else:
+        embedder = FrozenCLIPEmbedder().cuda()
+
+    EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])

     string_to_token_dict = {}
     string_to_param_dict = torch.nn.ParameterDict()
@@ -63,7 +99,7 @@ def get_placeholder_loop(placeholder_string, tokenizer):

             placeholder_to_src[placeholder_string] = manager_ckpt
         else:
-            new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
+            new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
             string_to_token_dict[new_placeholder] = new_token
             string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
@@ -77,7 +113,3 @@ def get_placeholder_loop(placeholder_string, tokenizer):

     print("Managers merged. Final list of placeholders: ")
     print(placeholder_to_src)
-
-
-
-
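
The magic numbers in the two single-token checks come from the tokenizers' special tokens. CLIP's vocabulary ends with <|startoftext|> = 49406 and <|endoftext|> = 49407, and the Hugging Face CLIPTokenizer pads with <|endoftext|>; subtracting 49407 therefore zeroes the EOS token and all padding, so a nonzero count of 2 means the BOS plus exactly one content token. BERT pads with 0 and wraps the string in [CLS]/[SEP], so a count of 3 likewise means exactly one content token. A standalone sketch of the CLIP case (the model name is FrozenCLIPEmbedder's default, openai/clip-vit-large-patch14; "cat" and "school bus" are illustrative inputs):

    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    def is_single_token(string):
        # Same tokenizer call and test as get_clip_token_for_string above.
        tokens = tokenizer(string, truncation=True, max_length=77,
                           padding="max_length", return_tensors="pt")["input_ids"]
        # BOS (49406) and content tokens survive the subtraction; the EOS and
        # the padding (both 49407) become zero.
        return torch.count_nonzero(tokens - 49407) == 2

    assert is_single_token("cat")             # one BPE token: accepted
    assert not is_single_token("school bus")  # two BPE tokens: loop re-prompts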
