
Commit c1079f0

dsocek and sayakpaul authored
Fix textual inversion SDXL and add support for 2nd text encoder (#9010)
* Fix textual inversion SDXL and add support for 2nd text encoder

  Signed-off-by: Daniel Socek <[email protected]>

* Fix style/quality of text inv for sdxl

  Signed-off-by: Daniel Socek <[email protected]>

---------

Signed-off-by: Daniel Socek <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 65e3090 · commit c1079f0

File tree

2 files changed: +81 −12 lines


examples/textual_inversion/README_sdxl.md

Lines changed: 22 additions & 1 deletion
````diff
@@ -23,4 +23,25 @@ accelerate launch textual_inversion_sdxl.py \
   --output_dir="./textual_inversion_cat_sdxl"
 ```
 
-For now, only training of the first text encoder is supported.
+Training of both text encoders is supported.
+
+### Inference Example
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionXLPipeline`.
+Make sure to include the `placeholder_token` in your prompt.
+
+```python
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+model_id = "./textual_inversion_cat_sdxl"
+pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A <cat-toy> backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack.png")
+
+image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack-prompt_2.png")
+```
````
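The README's example loads the full saved pipeline directory. As an alternative, the two embedding files this commit saves can be loaded into a stock SDXL pipeline. A minimal sketch, assuming the run produced `learned_embeds.safetensors` (the pre-existing first-encoder output, not shown in this diff) and `learned_embeds_2.safetensors` (added below) in the output directory; `load_textual_inversion` comes from diffusers' `TextualInversionLoaderMixin`:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Attach the learned vector to each encoder/tokenizer pair separately.
pipe.load_textual_inversion(
    "./textual_inversion_cat_sdxl/learned_embeds.safetensors",
    token="<cat-toy>",
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
)
pipe.load_textual_inversion(
    "./textual_inversion_cat_sdxl/learned_embeds_2.safetensors",
    token="<cat-toy>",
    text_encoder=pipe.text_encoder_2,
    tokenizer=pipe.tokenizer_2,
)

image = pipe("A <cat-toy> backpack", num_inference_steps=50).images[0]
```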

examples/textual_inversion/textual_inversion_sdxl.py

Lines changed: 59 additions & 11 deletions
```diff
@@ -135,7 +135,7 @@ def log_validation(
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         text_encoder=accelerator.unwrap_model(text_encoder_1),
-        text_encoder_2=text_encoder_2,
+        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
         tokenizer=tokenizer_1,
         tokenizer_2=tokenizer_2,
         unet=unet,
```
```diff
@@ -678,36 +678,54 @@ def main():
             f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
             " `placeholder_token` that is not already in the tokenizer."
         )
+    num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
+    if num_added_tokens != args.num_vectors:
+        raise ValueError(
+            f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+            " `placeholder_token` that is not already in the tokenizer."
+        )
 
     # Convert the initializer_token, placeholder_token to ids
     token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
+    token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
+
     # Check if initializer_token is a single token or a sequence of tokens
-    if len(token_ids) > 1:
+    if len(token_ids) > 1 or len(token_ids_2) > 1:
         raise ValueError("The initializer token must be a single token.")
 
     initializer_token_id = token_ids[0]
     placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
+    initializer_token_id_2 = token_ids_2[0]
+    placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder_1.resize_token_embeddings(len(tokenizer_1))
+    text_encoder_2.resize_token_embeddings(len(tokenizer_2))
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder_1.get_input_embeddings().weight.data
+    token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
     with torch.no_grad():
         for token_id in placeholder_token_ids:
             token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+        for token_id in placeholder_token_ids_2:
+            token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
 
     # Freeze vae and unet
     vae.requires_grad_(False)
     unet.requires_grad_(False)
-    text_encoder_2.requires_grad_(False)
+
     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder_1.text_model.encoder.requires_grad_(False)
     text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
     text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder_2.text_model.encoder.requires_grad_(False)
+    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
+        text_encoder_2.gradient_checkpointing_enable()
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
```
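This hunk replicates the standard textual-inversion bootstrap for the second encoder: add the placeholder token to the tokenizer, grow the embedding matrix, and seed the new row from the initializer token's vector. A self-contained sketch of that mechanism with `transformers` (checkpoint and token strings are illustrative, not taken from the diff):

```python
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

# SDXL's second encoder/tokenizer live in the "*_2" subfolders of the base repo.
repo = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
text_encoder = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

placeholder, initializer = "<cat-toy>", "toy"
tokenizer.add_tokens([placeholder])

# The initializer must map to exactly one token, as the script enforces.
init_ids = tokenizer.encode(initializer, add_special_tokens=False)
assert len(init_ids) == 1

# Grow the embedding matrix to cover the new vocab entry, then copy the
# initializer's row into the placeholder's row as a starting point.
text_encoder.resize_token_embeddings(len(tokenizer))
embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
    new_id = tokenizer.convert_tokens_to_ids(placeholder)
    embeds[new_id] = embeds[init_ids[0]].clone()
```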
```diff
@@ -746,7 +764,11 @@ def main():
         optimizer_class = torch.optim.AdamW
 
     optimizer = optimizer_class(
-        text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
+        # only optimize the embeddings
+        [
+            text_encoder_1.text_model.embeddings.token_embedding.weight,
+            text_encoder_2.text_model.embeddings.token_embedding.weight,
+        ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
```
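Note the optimizer change: instead of one encoder's `get_input_embeddings().parameters()`, a flat list of the two `token_embedding.weight` tensors is passed, so a single optimizer updates both embedding matrices. A minimal sketch of the same pattern with dummy embeddings (sizes are illustrative):

```python
import torch

emb_1 = torch.nn.Embedding(49409, 768)   # stand-ins for the two token-embedding
emb_2 = torch.nn.Embedding(49409, 1280)  # matrices; vocab sizes and dims illustrative

# A flat list of tensors is a valid parameter group for AdamW.
optimizer = torch.optim.AdamW([emb_1.weight, emb_2.weight], lr=5e-4)

# One backward pass produces gradients on both matrices; one step updates both.
loss = emb_1(torch.tensor([42])).sum() + emb_2(torch.tensor([42])).sum()
loss.backward()
optimizer.step()
```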
```diff
@@ -786,9 +808,10 @@ def main():
     )
 
     text_encoder_1.train()
+    text_encoder_2.train()
     # Prepare everything with our `accelerator`.
-    text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder_1, optimizer, train_dataloader, lr_scheduler
+    text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
     )
 
     # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
```
```diff
@@ -866,11 +889,13 @@ def main():
 
     # keep original embeddings as reference
     orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
+    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
 
     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder_1.train()
+        text_encoder_2.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder_1):
+            with accelerator.accumulate([text_encoder_1, text_encoder_2]):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * vae.config.scaling_factor
```
```diff
@@ -892,9 +917,7 @@ def main():
                     .hidden_states[-2]
                     .to(dtype=weight_dtype)
                 )
-                encoder_output_2 = text_encoder_2(
-                    batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
-                )
+                encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
                 encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
                 original_size = [
                     (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
```
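The removed reshape was presumably redundant because the dataloader already yields `input_ids_2` in `(batch, seq)` shape. For reference, SDXL conditions on the penultimate hidden states of both encoders plus the pooled projection of the second; a self-contained sketch of pulling those tensors out (the concatenation shown is the usual SDXL recipe, not part of this hunk):

```python
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

repo = "stabilityai/stable-diffusion-xl-base-1.0"
tok_1 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
tok_2 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
enc_1 = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
enc_2 = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

prompt = "A <cat-toy> backpack"
ids_1 = tok_1(prompt, padding="max_length", max_length=tok_1.model_max_length, return_tensors="pt").input_ids
ids_2 = tok_2(prompt, padding="max_length", max_length=tok_2.model_max_length, return_tensors="pt").input_ids

with torch.no_grad():
    out_1 = enc_1(ids_1, output_hidden_states=True)
    out_2 = enc_2(ids_2, output_hidden_states=True)

# hidden_states[-2] is the penultimate layer; SDXL uses it instead of the final one.
hidden_1 = out_1.hidden_states[-2]  # (1, 77, 768)
hidden_2 = out_2.hidden_states[-2]  # (1, 77, 1280)
pooled_2 = out_2.text_embeds        # (1, 1280), pooled projection for added conditioning

# The UNet's cross-attention input concatenates both encoders' states.
prompt_embeds = torch.cat([hidden_1, hidden_2], dim=-1)  # (1, 77, 2048)
```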
```diff
@@ -938,11 +961,16 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+                index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
+                index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
 
                 with torch.no_grad():
                     accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
                         index_no_updates
                     ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
+                        index_no_updates_2
+                    ] = orig_embeds_params_2[index_no_updates_2]
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
```
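This is the trick that makes textual inversion train only the new rows: the optimizer step touches the whole embedding matrix, so after each step every row except the placeholder rows is restored from a saved reference copy. A minimal self-contained sketch of the mechanism (sizes and ids are illustrative):

```python
import torch

vocab_size, dim = 10, 4
embedding = torch.nn.Embedding(vocab_size, dim)
placeholder_token_ids = [8, 9]           # illustrative new-token rows
orig = embedding.weight.data.clone()     # reference copy, as in the script

# Simulate an optimizer step that perturbed the whole matrix.
embedding.weight.data += 0.1

# Boolean mask: True for rows that must be reset to the reference.
keep = torch.ones(vocab_size, dtype=torch.bool)
keep[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False

with torch.no_grad():
    embedding.weight[keep] = orig[keep]

# Only the placeholder rows still differ from the reference.
assert not torch.equal(embedding.weight.data[8], orig[8])
assert torch.equal(embedding.weight.data[:8], orig[:8])
```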
```diff
@@ -960,6 +988,16 @@ def main():
                         save_path,
                         safe_serialization=True,
                     )
+                    weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
+                    save_path = os.path.join(args.output_dir, weight_name)
+                    save_progress(
+                        text_encoder_2,
+                        placeholder_token_ids_2,
+                        accelerator,
+                        args,
+                        save_path,
+                        safe_serialization=True,
+                    )
 
                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
```
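`save_progress` is defined earlier in the script and not shown in this diff; it slices out the placeholder rows and writes them out. A hedged sketch of what such a helper typically does with `safetensors` (the helper name and exact dict layout here are assumptions, not read from the diff):

```python
import torch
from safetensors.torch import save_file

def save_learned_embeds(embedding_weight, placeholder_token_ids, token, save_path):
    # Only the newly learned rows are worth saving; the rest is the frozen vocab.
    rows = embedding_weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
    # safetensors wants a flat {name: tensor} dict of contiguous CPU tensors.
    save_file({token: rows.detach().cpu().contiguous()}, save_path)

# Illustrative usage with a dummy matrix sized like the second encoder's vocab.
weight = torch.randn(49410, 1280)
save_learned_embeds(weight, [49408, 49409], "<cat-toy>", "learned_embeds_2.safetensors")
```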
```diff
@@ -1034,7 +1072,7 @@ def main():
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             text_encoder=accelerator.unwrap_model(text_encoder_1),
-            text_encoder_2=text_encoder_2,
+            text_encoder_2=accelerator.unwrap_model(text_encoder_2),
             vae=vae,
             unet=unet,
             tokenizer=tokenizer_1,
```
```diff
@@ -1052,6 +1090,16 @@ def main():
             save_path,
             safe_serialization=True,
         )
+        weight_name = "learned_embeds_2.safetensors"
+        save_path = os.path.join(args.output_dir, weight_name)
+        save_progress(
+            text_encoder_2,
+            placeholder_token_ids_2,
+            accelerator,
+            args,
+            save_path,
+            safe_serialization=True,
+        )
 
         if args.push_to_hub:
             save_model_card(
```
