Skip to content

Commit 1f3adba

Browse files
authored
Merge branch 'main' into refac/training_utils.py
2 parents 71a9899 + 92d2baf commit 1f3adba

22 files changed

+462
-131
lines changed

benchmarks/push_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pandas as pd
55
from huggingface_hub import hf_hub_download, upload_file
6-
from huggingface_hub.utils._errors import EntryNotFoundError
6+
from huggingface_hub.utils import EntryNotFoundError
77

88

99
sys.path.append(".")

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,6 @@ def encode_prompt(
985985
text_input_ids_list=None,
986986
):
987987
prompt = [prompt] if isinstance(prompt, str) else prompt
988-
batch_size = len(prompt)
989988
dtype = text_encoders[0].dtype
990989

991990
pooled_prompt_embeds = _encode_prompt_with_clip(
@@ -1007,8 +1006,7 @@ def encode_prompt(
10071006
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
10081007
)
10091008

1010-
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
1011-
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
1009+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
10121010

10131011
return prompt_embeds, pooled_prompt_embeds, text_ids
10141012

0 commit comments

Comments
 (0)