@@ -2154,6 +2154,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21542154
21552155 # encode batch prompts when custom prompts are provided for each image -
21562156 if train_dataset .custom_instance_prompts :
2157+ elems_to_repeat = 1
21572158 if freeze_text_encoder :
21582159 prompt_embeds , pooled_prompt_embeds , text_ids = compute_text_embeddings (
21592160 prompts , text_encoders , tokenizers
@@ -2168,17 +2169,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21682169 max_sequence_length = args .max_sequence_length ,
21692170 add_special_tokens = add_special_tokens_t5 ,
21702171 )
2172+ else :
2173+ elems_to_repeat = len (prompts )
21712174
21722175 if not freeze_text_encoder :
21732176 prompt_embeds , pooled_prompt_embeds , text_ids = encode_prompt (
21742177 text_encoders = [text_encoder_one , text_encoder_two ],
21752178 tokenizers = [None , None ],
2176- text_input_ids_list = [tokens_one , tokens_two ],
2179+ text_input_ids_list = [
2180+ tokens_one .repeat (elems_to_repeat , 1 ),
2181+ tokens_two .repeat (elems_to_repeat , 1 ),
2182+ ],
21772183 max_sequence_length = args .max_sequence_length ,
21782184 device = accelerator .device ,
21792185 prompt = prompts ,
21802186 )
2181-
21822187 # Convert images to latent space
21832188 if args .cache_latents :
21842189 model_input = latents_cache [step ].sample ()
@@ -2371,6 +2376,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
23712376 epoch = epoch ,
23722377 torch_dtype = weight_dtype ,
23732378 )
2379+ images = None
2380+ del pipeline
2381+
23742382 if freeze_text_encoder :
23752383 del text_encoder_one , text_encoder_two
23762384 free_memory ()
@@ -2448,6 +2456,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
24482456 commit_message = "End of training" ,
24492457 ignore_patterns = ["step_*" , "epoch_*" ],
24502458 )
2459+ images = None
2460+ del pipeline
24512461
24522462 accelerator .end_training ()
24532463
0 commit comments