
Commit 43672b4

aandyw and sayakpaul authored
Fix "push_to_hub only create repo in consistency model lora SDXL training script" (#6102)
* fix
* style fix

Co-authored-by: Sayak Paul <[email protected]>
1 parent 9df3d84 commit 43672b4
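
All four scripts get the same two-part fix: create_repo was previously called only for its side effect, so the repository was created on the Hub but the trained weights were never uploaded. The patch captures the repository id and calls upload_folder once training finishes. Below is a minimal sketch of the corrected flow, assuming the standard huggingface_hub API; the output directory name and the trailing .repo_id accessor are illustrative, not copied from the scripts.

# Sketch of the corrected push_to_hub flow, not the full training script.
# Assumes prior authentication (e.g. via `huggingface-cli login`).
from pathlib import Path

from huggingface_hub import create_repo, upload_folder

output_dir = "lcm-distill-lora-sd"  # hypothetical output directory
Path(output_dir).mkdir(exist_ok=True)

# Before the fix the return value was discarded, leaving no repo id to
# upload to later. RepoUrl.repo_id gives the "namespace/name" form that
# upload_folder expects.
repo_id = create_repo(
    repo_id=Path(output_dir).name,
    exist_ok=True,
).repo_id

# ... training runs and saves the final weights into output_dir ...

# The previously missing step: push the saved weights to the Hub,
# skipping intermediate checkpoint directories.
upload_folder(
    repo_id=repo_id,
    folder_path=output_dir,
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)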

4 files changed (+40, -8 lines)


examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py

Lines changed: 10 additions & 2 deletions
@@ -38,7 +38,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
 from torch.utils.data import default_collate
@@ -847,7 +847,7 @@ def main(args):
         os.makedirs(args.output_dir, exist_ok=True)

         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1366,6 +1366,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
         StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)

+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py

Lines changed: 10 additions & 2 deletions
@@ -39,7 +39,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
 from torch.utils.data import default_collate
@@ -842,7 +842,7 @@ def main(args):
         os.makedirs(args.output_dir, exist_ok=True)

         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1424,6 +1424,14 @@ def compute_embeddings(
         lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
         StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)

+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
examples/consistency_distillation/train_lcm_distill_sd_wds.py

Lines changed: 10 additions & 2 deletions
@@ -38,7 +38,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torch.utils.data import default_collate
 from torchvision import transforms
@@ -835,7 +835,7 @@ def main(args):
         os.makedirs(args.output_dir, exist_ok=True)

         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1354,6 +1354,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         target_unet = accelerator.unwrap_model(target_unet)
         target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))

+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
examples/consistency_distillation/train_lcm_distill_sdxl_wds.py

Lines changed: 10 additions & 2 deletions
@@ -39,7 +39,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torch.utils.data import default_collate
 from torchvision import transforms
@@ -875,7 +875,7 @@ def main(args):
         os.makedirs(args.output_dir, exist_ok=True)

         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1457,6 +1457,14 @@ def compute_embeddings(
         target_unet = accelerator.unwrap_model(target_unet)
         target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))

+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
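After this change, a repository created by one of these scripts should contain the saved weights rather than sitting empty. A quick way to confirm, using huggingface_hub's list_repo_files; the repository name below is a placeholder:

from huggingface_hub import list_repo_files

# Hypothetical repository created by a training run with --push_to_hub.
files = list_repo_files("your-username/lcm-distill-lora-sd")
# Before the fix this listed little beyond .gitattributes; afterwards it
# should include the saved weights, e.g. the unet_lora/ folder.
print(files)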