20
20
import math
21
21
import os
22
22
import sys
23
+ import time
23
24
import warnings
24
25
from pathlib import Path
25
- from typing import Optional
26
+ from typing import Optional , Type
26
27
27
28
import numpy as np
28
29
import paddle
29
30
import paddle .nn as nn
30
31
import paddle .nn .functional as F
31
- from huggingface_hub import HfFolder , Repository , create_repo , whoami
32
+ import requests
33
+ from huggingface_hub import HfFolder , create_repo , upload_folder , whoami
32
34
from paddle .distributed .fleet .utils .hybrid_parallel_util import (
33
35
fused_allreduce_gradients ,
34
36
)
54
56
from ppdiffusers .optimization import get_scheduler
55
57
56
58
59
# Since HF sometimes timeout, we need to retry uploads
# Credit: https://github.com/huggingface/datasets/blob/06ae3f678651bfbb3ca7dd3274ee2f38e0e0237e/src/datasets/utils/file_utils.py#L265
def _retry(
    func,
    func_args: Optional[tuple] = None,
    func_kwargs: Optional[dict] = None,
    exceptions: Type[requests.exceptions.RequestException] = requests.exceptions.RequestException,
    max_retries: int = 0,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
):
    """Call ``func`` and retry on ``exceptions`` with exponential backoff.

    Args:
        func: callable to invoke.
        func_args: positional arguments for ``func`` (default: none).
        func_kwargs: keyword arguments for ``func`` (default: none).
        exceptions: exception type(s) that trigger a retry; anything else
            propagates immediately.
        max_retries: number of retries after the first attempt (0 = no retry).
        base_wait_time: initial backoff delay in seconds.
        max_wait_time: cap on the backoff delay in seconds.

    Returns:
        Whatever ``func`` returns on the first successful call.

    Raises:
        The last caught exception once ``max_retries`` is exhausted.
    """
    func_args = func_args or ()
    func_kwargs = func_kwargs or {}
    retry = 0
    while True:
        try:
            return func(*func_args, **func_kwargs)
        except exceptions:
            if retry >= max_retries:
                # Out of attempts: re-raise with the original traceback intact.
                raise
            # Exponential backoff, capped at max_wait_time.
            sleep_time = min(max_wait_time, base_wait_time * 2**retry)
            # BUGFIX: was f"[{retry / max_retries}]", which printed a float
            # ratio (e.g. 0.2) instead of the intended "attempt X/Y" counter.
            logger.info(f"{func} timed out, retrying in {sleep_time}s... [{retry + 1}/{max_retries}]")
            time.sleep(sleep_time)
            retry += 1
84
+
85
+
57
86
def url_or_path_join(*path_list):
    """Join segments as a filesystem path when the joined result is an
    existing local directory; otherwise join them URL-style with "/"."""
    candidate = os.path.join(*path_list)
    if os.path.isdir(candidate):
        return candidate
    return "/".join(path_list)
59
88
@@ -561,21 +590,7 @@ def main():
561
590
gc .collect ()
562
591
563
592
if is_main_process :
564
- if args .push_to_hub :
565
- if args .hub_model_id is None :
566
- repo_name = get_full_repo_name (Path (args .output_dir ).name , token = args .hub_token )
567
- else :
568
- repo_name = args .hub_model_id
569
-
570
- create_repo (repo_name , exist_ok = True , token = args .hub_token )
571
- repo = Repository (args .output_dir , clone_from = repo_name , token = args .hub_token )
572
-
573
- with open (os .path .join (args .output_dir , ".gitignore" ), "w+" ) as gitignore :
574
- if "step_*" not in gitignore :
575
- gitignore .write ("step_*\n " )
576
- if "epoch_*" not in gitignore :
577
- gitignore .write ("epoch_*\n " )
578
- elif args .output_dir is not None :
593
+ if args .output_dir is not None :
579
594
os .makedirs (args .output_dir , exist_ok = True )
580
595
581
596
# Load the tokenizer
@@ -877,15 +892,6 @@ def collate_fn(examples):
877
892
unet = unwrap_model (unet )
878
893
unet .save_attn_procs (args .output_dir )
879
894
880
- if args .push_to_hub :
881
- save_model_card (
882
- repo_name ,
883
- images = images ,
884
- base_model = args .pretrained_model_name_or_path ,
885
- prompt = args .instance_prompt ,
886
- repo_folder = args .output_dir ,
887
- )
888
- repo .push_to_hub (commit_message = "End of training" , blocking = False , auto_lfs_prune = True )
889
895
# Final inference
890
896
# Load previous pipeline
891
897
pipeline = DiffusionPipeline .from_pretrained (args .pretrained_model_name_or_path , safety_checker = None )
@@ -909,6 +915,45 @@ def collate_fn(examples):
909
915
910
916
writer .close ()
911
917
918
+ # logic to push to HF Hub
919
+ if args .push_to_hub :
920
+ if args .hub_model_id is None :
921
+ repo_name = get_full_repo_name (Path (args .output_dir ).name , token = args .hub_token )
922
+ else :
923
+ repo_name = args .hub_model_id
924
+
925
+ _retry (
926
+ create_repo ,
927
+ func_kwargs = {"repo_id" : repo_name , "exist_ok" : True , "token" : args .hub_token },
928
+ base_wait_time = 1.0 ,
929
+ max_retries = 5 ,
930
+ max_wait_time = 10.0 ,
931
+ )
932
+
933
+ save_model_card (
934
+ repo_name ,
935
+ images = images ,
936
+ base_model = args .pretrained_model_name_or_path ,
937
+ prompt = args .instance_prompt ,
938
+ repo_folder = args .output_dir ,
939
+ )
940
+ # Upload model
941
+ logger .info (f"Pushing to { repo_name } " )
942
+ _retry (
943
+ upload_folder ,
944
+ func_kwargs = {
945
+ "repo_id" : repo_name ,
946
+ "repo_type" : "model" ,
947
+ "folder_path" : args .output_dir ,
948
+ "commit_message" : "End of training" ,
949
+ "token" : args .hub_token ,
950
+ "ignore_patterns" : ["checkpoint-*/*" , "logs/*" ],
951
+ },
952
+ base_wait_time = 1.0 ,
953
+ max_retries = 5 ,
954
+ max_wait_time = 20.0 ,
955
+ )
956
+
912
957
913
958
if __name__ == "__main__" :
914
959
main ()
0 commit comments