Commit af92869

asomoza, linoytsaban, and sayakpaul authored
[SD3 LoRA Training] Fix errors when not training text encoders (#8743)
* fix
* fix things.
  Co-authored-by: Linoy Tsaban <[email protected]>
* remove patch
* apply suggestions

---------

Co-authored-by: Linoy Tsaban <[email protected]>
Co-authored-by: sayakpaul <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
1 parent 0bae6e4 commit af92869

File tree

1 file changed (+16 −16 lines)

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 16 additions & 16 deletions
@@ -962,7 +962,7 @@ def encode_prompt(
             prompt=prompt,
             device=device if device is not None else text_encoder.device,
             num_images_per_prompt=num_images_per_prompt,
-            text_input_ids=text_input_ids_list[i],
+            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
         )
         clip_prompt_embeds_list.append(prompt_embeds)
         clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
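This hunk and the next guard the indexing of text_input_ids_list. When the text encoders are frozen, encode_prompt can be called without pre-tokenized prompts, so the argument may arrive as None (or empty) and the old unconditional text_input_ids_list[i] raised before any encoding happened. A minimal, self-contained sketch of the guard; ids_list is a stand-in, not the script's API:

# Indexing None raises TypeError; indexing [] raises IndexError.
# The ternary returns None instead, so the encoder falls back to
# tokenizing `prompt` itself.
ids_list = None  # no pre-tokenized prompts when encoders are frozen
i = 0
ids = ids_list[i] if ids_list else None
assert ids is None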
@@ -976,7 +976,7 @@ def encode_prompt(
         max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[:-1],
+        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
         device=device if device is not None else text_encoders[-1].device,
     )
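Note that this hunk also swaps a slice for an index: text_input_ids_list[:-1] is everything except the last element, while the T5 encoder needs exactly the last entry, text_input_ids_list[-1]. A quick illustration with stand-in values:

tokens = ["clip_one_ids", "clip_two_ids", "t5_ids"]
print(tokens[:-1])  # ['clip_one_ids', 'clip_two_ids'] -- a list slice, the wrong value
print(tokens[-1])   # 't5_ids' -- the single T5 entry the encoder expects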
@@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         ) = accelerator.prepare(
             transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
         )
+        assert text_encoder_one is not None
+        assert text_encoder_two is not None
+        assert text_encoder_three is not None
     else:
         transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, optimizer, train_dataloader, lr_scheduler
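The new asserts make the invariant explicit: when text-encoder training was requested, none of the encoders may be None after accelerator.prepare, so a mistake fails fast here instead of deep inside encode_prompt. A self-contained sketch of the pattern, with prepare as a hypothetical stand-in for accelerate.Accelerator.prepare:

def prepare(*modules):
    # Stand-in: the real accelerator.prepare wraps and returns its arguments.
    return modules if len(modules) > 1 else modules[0]

train_text_encoder = True
transformer, enc_one, enc_two, enc_three = "tf", "e1", "e2", "e3"

if train_text_encoder:
    transformer, enc_one, enc_two = prepare(transformer, enc_one, enc_two)
    # Fail fast on a silent unpacking or wrapping mistake.
    assert enc_one is not None
    assert enc_two is not None
    assert enc_three is not None
else:
    transformer = prepare(transformer)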
@@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     tokens_three = tokenize_prompt(tokenizer_three, prompts)
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
-                        tokenizers=[None, None, tokenizer_three],
+                        tokenizers=[None, None, None],
                         prompt=prompts,
                         max_sequence_length=args.max_sequence_length,
                         text_input_ids_list=[tokens_one, tokens_two, tokens_three],
@@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                         tokenizers=[None, None, tokenizer_three],
-                        prompt=prompts,
+                        prompt=args.instance_prompt,
                         max_sequence_length=args.max_sequence_length,
                         text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                     )
@@ -1685,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = itertools.chain(
-                        transformer_lora_parameters,
-                        text_lora_parameters_one,
-                        text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
+                    params_to_clip = (
+                        itertools.chain(
+                            transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
+                        )
+                        if args.train_text_encoder
+                        else transformer_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
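The old expression always built a chain: with the encoders frozen it fell back to chaining transformer_lora_parameters with itself, double-counting it for gradient clipping, while still referencing text-encoder parameter lists that only exist when the encoders are trained. The rewrite selects the whole iterable conditionally. A self-contained illustration with stand-in lists:

import itertools

transformer_params = ["t1", "t2"]
text_params_one, text_params_two = ["e1"], ["e2"]  # in the real script these
train_text_encoder = False                         # exist only when training

# Old shape: transformer parameters appear twice when encoders are frozen.
old = itertools.chain(
    transformer_params,
    text_params_one,
    text_params_two if train_text_encoder else transformer_params,
)
print(list(old))  # ['t1', 't2', 'e1', 't1', 't2']

# New shape: one iterable or the other.
new = (
    itertools.chain(transformer_params, text_params_one, text_params_two)
    if train_text_encoder
    else transformer_params
)
print(list(new))  # ['t1', 't2']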
@@ -1741,13 +1746,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                     )
-                else:
-                    text_encoder_three = text_encoder_cls_three.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        subfolder="text_encoder_3",
-                        revision=args.revision,
-                        variant=args.variant,
-                    )
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1767,7 +1765,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
-                del text_encoder_one, text_encoder_two, text_encoder_three
+                if not args.train_text_encoder:
+                    del text_encoder_one, text_encoder_two, text_encoder_three
+
                 torch.cuda.empty_cache()
                 gc.collect()
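The last two hunks tie validation-time loading and freeing together: the removed else branch loaded a separate text_encoder_3 on a path where the encoders were already available, and the old del freed all three unconditionally, which would also discard encoders that are still being trained. A sketch of the lifecycle after the fix, with a hypothetical loader and stand-in objects:

import gc

def load_text_encoders():
    return "enc_one", "enc_two", "enc_three"  # stand-ins for the real models

train_text_encoder = False

enc_one, enc_two, enc_three = load_text_encoders()
# ... build the validation pipeline and run validation ...
if not train_text_encoder:
    # Free the encoders only when they were loaded just for validation.
    del enc_one, enc_two, enc_three
gc.collect()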