Skip to content

Commit aa6b6e2

Browse files
committed
shapes
1 parent c296b6f commit aa6b6e2

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,7 +1454,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14541454

14551455
# Clear the memory here
14561456
if not train_dataset.custom_instance_prompts:
1457-
del text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four, tokenizer_one, tokenizer_two,tokenizer_three, tokenizer_four
1457+
# delete tokenizers and text encoders except for llama (tokenizer & te four)
1458+
# as it's needed for inference with pipeline
1459+
del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two,tokenizer_three
1460+
if not args.validation_prompt:
1461+
del tokenizer_four, text_encoder_four
14581462
free_memory()
14591463

14601464
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
@@ -1739,8 +1743,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17391743
# create pipeline
17401744
pipeline = HiDreamImagePipeline.from_pretrained(
17411745
args.pretrained_model_name_or_path,
1742-
# tokenizer_4=tokenizer_4,
1743-
# text_encoder_4=text_encoder_4,
1746+
tokenizer_4=tokenizer_four,
1747+
text_encoder_4=text_encoder_four,
17441748
transformer=accelerator.unwrap_model(transformer),
17451749
revision=args.revision,
17461750
variant=args.variant,

0 commit comments

Comments
 (0)