Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,8 +1007,7 @@ def encode_prompt(
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)

text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

return prompt_embeds, pooled_prompt_embeds, text_ids

Expand Down