Skip to content

Commit 5257b46

Browse files
committed
move prompt embeds, pooled embeds outside
1 parent 4e08343 commit 5257b46

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,7 +1591,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15911591
# encode batch prompts when custom prompts are provided for each image -
15921592
if train_dataset.custom_instance_prompts:
15931593
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
1594-
1594+
else:
1595+
prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1)
1596+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1)
15951597
# Convert images to latent space
15961598
if args.cache_latents:
15971599
model_input = latents_cache[step].sample()
@@ -1646,12 +1648,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16461648
# Predict the noise residual
16471649
model_pred = transformer(
16481650
hidden_states=noisy_model_input,
1649-
encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
1650-
if not train_dataset.custom_instance_prompts
1651-
else prompt_embeds,
1652-
pooled_embeds=pooled_prompt_embeds.repeat(len(prompts), 1)
1653-
if not train_dataset.custom_instance_prompts
1654-
else pooled_prompt_embeds,
1651+
encoder_hidden_states=prompt_embeds,
1652+
pooled_embeds=pooled_prompt_embeds,
16551653
timestep=timesteps,
16561654
img_sizes=img_sizes,
16571655
img_ids=img_ids,

0 commit comments

Comments
 (0)