@@ -1454,7 +1454,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Clear the memory here
     if not train_dataset.custom_instance_prompts:
-        del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two, tokenizer_three, tokenizer_four
+        # delete tokenizers and text encoders except for llama (tokenizer & te four)
+        # as it's needed for inference with pipeline
+        del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three
+        if not args.validation_prompt:
+            del tokenizer_four, text_encoder_four
         free_memory()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
@@ -1739,8 +1743,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # create pipeline
                 pipeline = HiDreamImagePipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    # tokenizer_4=tokenizer_4,
-                    # text_encoder_4=text_encoder_4,
+                    tokenizer_4=tokenizer_four,
+                    text_encoder_4=text_encoder_four,
                     transformer=accelerator.unwrap_model(transformer),
                     revision=args.revision,
                     variant=args.variant,
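
The first hunk frees the tokenizers and text encoders that are no longer needed once the prompt embeddings are precomputed, but keeps the Llama pair (tokenizer_four / text_encoder_four) whenever a validation prompt will still be run through the pipeline, which is then passed explicitly in the second hunk. Below is a minimal sketch of that pattern; the helper name release_text_encoders, the encoders dict, and the placeholder modules are illustrative and not part of the script.

import gc

import torch


def free_memory():
    # Drop unreferenced objects and clear the CUDA cache if a GPU is present.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def release_text_encoders(encoders, keep, validation_prompt):
    # Remove every encoder that is no longer needed once prompt embeddings are
    # precomputed; encoders listed in `keep` survive only while a validation
    # prompt still has to be encoded at inference time.
    for name in list(encoders):
        if name in keep and validation_prompt:
            continue
        del encoders[name]
    free_memory()
    return encoders


# Example: keep only the Llama encoder because a validation prompt is set.
encoders = release_text_encoders(
    {"clip_l": None, "clip_g": None, "t5": None, "llama": None},  # placeholder modules
    keep={"llama"},
    validation_prompt="a photo of sks dog",
)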