@@ -97,30 +97,25 @@ def on_save(
9797 """
9898
9999 def checkpoint (checkpoint_dir , save_dir ):
100- hf_converted_output_dir = os .path .join (
101- save_dir , "hf_converted_checkpoint"
102- )
103- if os .path .exists (hf_converted_output_dir ):
100+ if os .path .exists (save_dir ):
104101 # If the folder already exists we return early,
105102 # since this can happen when the checkpoint is
106103 # saved again at the end of training
107104 return
108- os . mkdir ( hf_converted_output_dir )
105+
109106 try :
110107 recover_safetensors_from_dcp (
111108 checkpoint_dir ,
112109 self .pretrained_model_name_or_path ,
113- hf_converted_output_dir ,
110+ save_dir ,
114111 )
115112 # Save tokenizer
116113 if self .trainer .processing_class :
117- self .trainer .processing_class .save_pretrained (
118- hf_converted_output_dir
119- )
114+ self .trainer .processing_class .save_pretrained (save_dir )
120115 # Save training args
121116 torch .save (
122117 args ,
123- os .path .join (hf_converted_output_dir , TRAINING_ARGS_NAME ),
118+ os .path .join (save_dir , TRAINING_ARGS_NAME ),
124119 )
125120
126121 # Unwrap FSDP module
@@ -135,16 +130,14 @@ def checkpoint(checkpoint_dir, save_dir):
135130 list (config_dict ["target_modules" ])
136131 )
137132 with open (
138- os .path .join (
139- hf_converted_output_dir , "adapter_config.json"
140- ),
133+ os .path .join (save_dir , "adapter_config.json" ),
141134 "w" ,
142135 encoding = "utf-8" ,
143136 ) as f :
144137 json .dump (config_dict , f , indent = 2 )
145138
146139 else :
147- model .config .save_pretrained (hf_converted_output_dir )
140+ model .config .save_pretrained (save_dir )
148141
149142 except Exception as e :
150143 raise ValueError (
@@ -157,15 +150,19 @@ def checkpoint(checkpoint_dir, save_dir):
157150 checkpoint_dir = os .path .join (
158151 args .output_dir , f"{ PREFIX_CHECKPOINT_DIR } -{ state .global_step } "
159152 )
160- checkpoint (checkpoint_dir , checkpoint_dir )
153+ hf_converted_path = os .path .join (
154+ checkpoint_dir , "hf_converted_checkpoint"
155+ )
156+ if not os .path .exists (hf_converted_path ):
157+ os .makedirs (hf_converted_path )
158+ checkpoint (checkpoint_dir , hf_converted_path )
161159
162160 # If final save directory is provided, save the model there
163161 if (
164162 getattr (self , "save_model_dir" , None )
165163 and state .global_step == state .max_steps
166164 ):
167- if not os .path .exists (self .save_model_dir ):
168- os .mkdir (self .save_model_dir )
165+ os .makedirs (self .save_model_dir , exist_ok = True )
169166 checkpoint (checkpoint_dir , self .save_model_dir )
170167
171168 callbacks .append (
0 commit comments