1- from ldm .modules .encoders .modules import BERTTokenizer
1+ from ldm .modules .encoders .modules import FrozenCLIPEmbedder , BERTEmbedder
22from ldm .modules .embedding_manager import EmbeddingManager
33
44import argparse , os
55from functools import partial
66
77import torch
88
9- def get_placeholder_loop (placeholder_string , tokenizer ):
9+ def get_placeholder_loop (placeholder_string , embedder , use_bert ):
1010
1111 new_placeholder = None
1212
@@ -16,10 +16,36 @@ def get_placeholder_loop(placeholder_string, tokenizer):
1616 else :
1717 new_placeholder = input (f"Placeholder string '{ new_placeholder } ' maps to more than a single token. Please enter another string: " )
1818
19- token = tokenizer (new_placeholder )
19+ token = get_bert_token_for_string (embedder .tknz_fn , new_placeholder ) if use_bert else get_clip_token_for_string (embedder .tokenizer , new_placeholder )
20+
21+ if token is not None :
22+ return new_placeholder , token
23+
def get_clip_token_for_string(tokenizer, string, eot_token_id=49407):
    """Return the single CLIP token id that *string* encodes to, or None.

    The string is run through the (HuggingFace-style) CLIP tokenizer with
    fixed-length padding. A valid placeholder must map to exactly ONE token:
    the padded sequence then contains exactly two entries that are not the
    end-of-text token — the begin-of-text marker and the word token itself.

    Args:
        tokenizer: HuggingFace CLIP tokenizer (callable returning a dict
            with an "input_ids" tensor).
        string: candidate placeholder string.
        eot_token_id: id used for end-of-text/padding by the tokenizer
            vocabulary (49407 for the standard CLIP vocab).

    Returns:
        A 0-d tensor holding the word's token id, or None when the string
        tokenizes to more than one token.
    """
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,  # CLIP's fixed context length
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )

    tokens = batch_encoding["input_ids"]

    # BOT marker + one word token -> exactly 2 non-EOT entries.
    if torch.count_nonzero(tokens - eot_token_id) == 2:
        return tokens[0, 1]

    return None
41+
def get_bert_token_for_string(tokenizer, string):
    """Return the single BERT token id for *string*, or None.

    A valid placeholder must encode to exactly one word token: the padded
    output then has exactly three non-zero entries ([CLS], the word token,
    [SEP]); anything else means the string spans multiple tokens.
    """
    encoded = tokenizer(string)
    return encoded[0, 1] if torch.count_nonzero(encoded) == 3 else None
2048
21- if torch .count_nonzero (token ) == 3 :
22- return new_placeholder , token [0 , 1 ]
2349
2450if __name__ == "__main__" :
2551
@@ -40,10 +66,20 @@ def get_placeholder_loop(placeholder_string, tokenizer):
4066 help = "Output path for the merged manager" ,
4167 )
4268
69+ parser .add_argument (
70+ "-sd" , "--use_bert" ,
71+ action = "store_true" ,
72+ help = "Flag to denote that we are not merging stable diffusion embeddings"
73+ )
74+
4375 args = parser .parse_args ()
4476
45- tokenizer = BERTTokenizer (vq_interface = False , max_length = 77 )
46- EmbeddingManager = partial (EmbeddingManager , tokenizer , ["*" ])
77+ if args .use_bert :
78+ embedder = BERTEmbedder (n_embed = 1280 , n_layer = 32 ).cuda ()
79+ else :
80+ embedder = FrozenCLIPEmbedder ().cuda ()
81+
82+ EmbeddingManager = partial (EmbeddingManager , embedder , ["*" ])
4783
4884 string_to_token_dict = {}
4985 string_to_param_dict = torch .nn .ParameterDict ()
@@ -63,7 +99,7 @@ def get_placeholder_loop(placeholder_string, tokenizer):
6399
64100 placeholder_to_src [placeholder_string ] = manager_ckpt
65101 else :
66- new_placeholder , new_token = get_placeholder_loop (placeholder_string , tokenizer )
102+ new_placeholder , new_token = get_placeholder_loop (placeholder_string , embedder , use_bert = args . use_bert )
67103 string_to_token_dict [new_placeholder ] = new_token
68104 string_to_param_dict [new_placeholder ] = manager .string_to_param_dict [placeholder_string ]
69105
@@ -77,7 +113,3 @@ def get_placeholder_loop(placeholder_string, tokenizer):
77113
78114 print ("Managers merged. Final list of placeholders: " )
79115 print (placeholder_to_src )
80-
81-
82-
83-
0 commit comments