Skip to content

Commit d993e16

Browse files
committed
prompt embeds
1 parent 33385c9 commit d993e16

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -907,8 +907,8 @@ def _encode_prompt_with_llama(
907907
else:
908908
dtype = text_encoder.dtype
909909

910-
prompt_embeds = outputs.hidden_states[1:].to(dtype=dtype, device=device)
911-
prompt_embeds = torch.stack(prompt_embeds, dim=0)
910+
prompt_embeds = outputs.hidden_states[1:]
911+
prompt_embeds = torch.stack(prompt_embeds, dim=0).to(dtype=dtype, device=device)
912912
_, _, seq_len, dim = prompt_embeds.shape
913913

914914
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
@@ -1060,6 +1060,8 @@ def encode_prompt(
10601060
attention_mask=attention_mask_list[1] if attention_mask_list else None,
10611061
)
10621062

1063+
print("t5_prompt_embeds",t5_prompt_embeds.shape)
1064+
print("llama3_prompt_embeds",llama3_prompt_embeds.shape)
10631065
prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
10641066

10651067
return prompt_embeds, pooled_prompt_embeds
@@ -1431,7 +1433,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14311433
prompt_embeds, pooled_prompt_embeds = encode_prompt(
14321434
text_encoders, tokenizers, prompt, args.max_sequence_length
14331435
)
1434-
prompt_embeds = prompt_embeds.to(accelerator.device)
1436+
prompt_embeds[0] = prompt_embeds[0].to(accelerator.device)
1437+
prompt_embeds[1] = prompt_embeds[1].to(accelerator.device)
14351438
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
14361439
return prompt_embeds, pooled_prompt_embeds
14371440

@@ -1587,7 +1590,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15871590
if train_dataset.custom_instance_prompts:
15881591
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
15891592
else:
1590-
prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1)
1593+
prompt_embeds[0] = prompt_embeds[0].repeat(len(prompts), 1, 1)
1594+
prompt_embeds[1] = prompt_embeds[1].repeat(1, len(prompts), 1, 1)
15911595
pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1)
15921596
# Convert images to latent space
15931597
if args.cache_latents:

0 commit comments

Comments
 (0)