Skip to content

Commit 2180cae

Browse files
authored
use upload_folder (#4828)
1 parent f406a4d commit 2180cae

File tree

1 file changed

+71
-26
lines changed

1 file changed

+71
-26
lines changed

ppdiffusers/examples/dreambooth/train_dreambooth_lora.py

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020
import math
2121
import os
2222
import sys
23+
import time
2324
import warnings
2425
from pathlib import Path
25-
from typing import Optional
26+
from typing import Optional, Type
2627

2728
import numpy as np
2829
import paddle
2930
import paddle.nn as nn
3031
import paddle.nn.functional as F
31-
from huggingface_hub import HfFolder, Repository, create_repo, whoami
32+
import requests
33+
from huggingface_hub import HfFolder, create_repo, upload_folder, whoami
3234
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
3335
fused_allreduce_gradients,
3436
)
@@ -54,6 +56,33 @@
5456
from ppdiffusers.optimization import get_scheduler
5557

5658

59+
# Since HF sometimes timeout, we need to retry uploads
60+
# Credit: https://github.com/huggingface/datasets/blob/06ae3f678651bfbb3ca7dd3274ee2f38e0e0237e/src/datasets/utils/file_utils.py#L265
61+
def _retry(
62+
func,
63+
func_args: Optional[tuple] = None,
64+
func_kwargs: Optional[dict] = None,
65+
exceptions: Type[requests.exceptions.RequestException] = requests.exceptions.RequestException,
66+
max_retries: int = 0,
67+
base_wait_time: float = 0.5,
68+
max_wait_time: float = 2,
69+
):
70+
func_args = func_args or ()
71+
func_kwargs = func_kwargs or {}
72+
retry = 0
73+
while True:
74+
try:
75+
return func(*func_args, **func_kwargs)
76+
except exceptions as err:
77+
if retry >= max_retries:
78+
raise err
79+
else:
80+
sleep_time = min(max_wait_time, base_wait_time * 2**retry) # Exponential backoff
81+
logger.info(f"{func} timed out, retrying in {sleep_time}s... [{retry/max_retries}]")
82+
time.sleep(sleep_time)
83+
retry += 1
84+
85+
5786
def url_or_path_join(*path_list):
5887
return os.path.join(*path_list) if os.path.isdir(os.path.join(*path_list)) else "/".join(path_list)
5988

@@ -561,21 +590,7 @@ def main():
561590
gc.collect()
562591

563592
if is_main_process:
564-
if args.push_to_hub:
565-
if args.hub_model_id is None:
566-
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
567-
else:
568-
repo_name = args.hub_model_id
569-
570-
create_repo(repo_name, exist_ok=True, token=args.hub_token)
571-
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
572-
573-
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
574-
if "step_*" not in gitignore:
575-
gitignore.write("step_*\n")
576-
if "epoch_*" not in gitignore:
577-
gitignore.write("epoch_*\n")
578-
elif args.output_dir is not None:
593+
if args.output_dir is not None:
579594
os.makedirs(args.output_dir, exist_ok=True)
580595

581596
# Load the tokenizer
@@ -877,15 +892,6 @@ def collate_fn(examples):
877892
unet = unwrap_model(unet)
878893
unet.save_attn_procs(args.output_dir)
879894

880-
if args.push_to_hub:
881-
save_model_card(
882-
repo_name,
883-
images=images,
884-
base_model=args.pretrained_model_name_or_path,
885-
prompt=args.instance_prompt,
886-
repo_folder=args.output_dir,
887-
)
888-
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
889895
# Final inference
890896
# Load previous pipeline
891897
pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, safety_checker=None)
@@ -909,6 +915,45 @@ def collate_fn(examples):
909915

910916
writer.close()
911917

918+
# logic to push to HF Hub
919+
if args.push_to_hub:
920+
if args.hub_model_id is None:
921+
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
922+
else:
923+
repo_name = args.hub_model_id
924+
925+
_retry(
926+
create_repo,
927+
func_kwargs={"repo_id": repo_name, "exist_ok": True, "token": args.hub_token},
928+
base_wait_time=1.0,
929+
max_retries=5,
930+
max_wait_time=10.0,
931+
)
932+
933+
save_model_card(
934+
repo_name,
935+
images=images,
936+
base_model=args.pretrained_model_name_or_path,
937+
prompt=args.instance_prompt,
938+
repo_folder=args.output_dir,
939+
)
940+
# Upload model
941+
logger.info(f"Pushing to {repo_name}")
942+
_retry(
943+
upload_folder,
944+
func_kwargs={
945+
"repo_id": repo_name,
946+
"repo_type": "model",
947+
"folder_path": args.output_dir,
948+
"commit_message": "End of training",
949+
"token": args.hub_token,
950+
"ignore_patterns": ["checkpoint-*/*", "logs/*"],
951+
},
952+
base_wait_time=1.0,
953+
max_retries=5,
954+
max_wait_time=20.0,
955+
)
956+
912957

913958
if __name__ == "__main__":
914959
main()

0 commit comments

Comments
 (0)