@@ -962,7 +962,7 @@ def encode_prompt(
962
962
prompt = prompt ,
963
963
device = device if device is not None else text_encoder .device ,
964
964
num_images_per_prompt = num_images_per_prompt ,
965
- text_input_ids = text_input_ids_list [i ],
965
+ text_input_ids = text_input_ids_list [i ] if text_input_ids_list else None ,
966
966
)
967
967
clip_prompt_embeds_list .append (prompt_embeds )
968
968
clip_pooled_prompt_embeds_list .append (pooled_prompt_embeds )
@@ -976,7 +976,7 @@ def encode_prompt(
976
976
max_sequence_length ,
977
977
prompt = prompt ,
978
978
num_images_per_prompt = num_images_per_prompt ,
979
- text_input_ids = text_input_ids_list [: - 1 ],
979
+ text_input_ids = text_input_ids_list [- 1 ] if text_input_ids_list else None ,
980
980
device = device if device is not None else text_encoders [- 1 ].device ,
981
981
)
982
982
@@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1491
1491
) = accelerator .prepare (
1492
1492
transformer , text_encoder_one , text_encoder_two , optimizer , train_dataloader , lr_scheduler
1493
1493
)
1494
+ assert text_encoder_one is not None
1495
+ assert text_encoder_two is not None
1496
+ assert text_encoder_three is not None
1494
1497
else :
1495
1498
transformer , optimizer , train_dataloader , lr_scheduler = accelerator .prepare (
1496
1499
transformer , optimizer , train_dataloader , lr_scheduler
@@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1598
1601
tokens_three = tokenize_prompt (tokenizer_three , prompts )
1599
1602
prompt_embeds , pooled_prompt_embeds = encode_prompt (
1600
1603
text_encoders = [text_encoder_one , text_encoder_two , text_encoder_three ],
1601
- tokenizers = [None , None , tokenizer_three ],
1604
+ tokenizers = [None , None , None ],
1602
1605
prompt = prompts ,
1603
1606
max_sequence_length = args .max_sequence_length ,
1604
1607
text_input_ids_list = [tokens_one , tokens_two , tokens_three ],
@@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1608
1611
prompt_embeds , pooled_prompt_embeds = encode_prompt (
1609
1612
text_encoders = [text_encoder_one , text_encoder_two , text_encoder_three ],
1610
1613
tokenizers = [None , None , tokenizer_three ],
1611
- prompt = prompts ,
1614
+ prompt = args . instance_prompt ,
1612
1615
max_sequence_length = args .max_sequence_length ,
1613
1616
text_input_ids_list = [tokens_one , tokens_two , tokens_three ],
1614
1617
)
@@ -1685,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1685
1688
1686
1689
accelerator .backward (loss )
1687
1690
if accelerator .sync_gradients :
1688
- params_to_clip = itertools .chain (
1689
- transformer_lora_parameters ,
1690
- text_lora_parameters_one ,
1691
- text_lora_parameters_two if args .train_text_encoder else transformer_lora_parameters ,
1691
+ params_to_clip = (
1692
+ itertools .chain (
1693
+ transformer_lora_parameters , text_lora_parameters_one , text_lora_parameters_two
1694
+ )
1695
+ if args .train_text_encoder
1696
+ else transformer_lora_parameters
1692
1697
)
1693
1698
accelerator .clip_grad_norm_ (params_to_clip , args .max_grad_norm )
1694
1699
@@ -1741,13 +1746,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1741
1746
text_encoder_one , text_encoder_two , text_encoder_three = load_text_encoders (
1742
1747
text_encoder_cls_one , text_encoder_cls_two , text_encoder_cls_three
1743
1748
)
1744
- else :
1745
- text_encoder_three = text_encoder_cls_three .from_pretrained (
1746
- args .pretrained_model_name_or_path ,
1747
- subfolder = "text_encoder_3" ,
1748
- revision = args .revision ,
1749
- variant = args .variant ,
1750
- )
1751
1749
pipeline = StableDiffusion3Pipeline .from_pretrained (
1752
1750
args .pretrained_model_name_or_path ,
1753
1751
vae = vae ,
@@ -1767,7 +1765,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1767
1765
pipeline_args = pipeline_args ,
1768
1766
epoch = epoch ,
1769
1767
)
1770
- del text_encoder_one , text_encoder_two , text_encoder_three
1768
+ if not args .train_text_encoder :
1769
+ del text_encoder_one , text_encoder_two , text_encoder_three
1770
+
1771
1771
torch .cuda .empty_cache ()
1772
1772
gc .collect ()
1773
1773
0 commit comments