Skip to content

Commit 1e844d3

Browse files
fix: directly save final ckpt in save_model_dir (#626)
Signed-off-by: yashasvi <[email protected]>
1 parent 3344193 commit 1e844d3

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,21 @@ def on_save(
9696
Also saves the final model in save_model_dir if provided.
9797
"""
9898

99-
def checkpoint(checkpoint_dir, save_dir):
100-
hf_converted_output_dir = os.path.join(
101-
save_dir, "hf_converted_checkpoint"
102-
)
103-
if os.path.exists(hf_converted_output_dir):
99+
def checkpoint(checkpoint_dir, save_dir, is_intermediate: bool = True):
100+
if is_intermediate:
101+
hf_converted_output_dir = os.path.join(
102+
save_dir, "hf_converted_checkpoint"
103+
)
104+
else:
105+
hf_converted_output_dir = save_dir
106+
107+
if os.path.exists(hf_converted_output_dir) and is_intermediate:
104108
# If the folder already exists
105109
# we return, since this is possible to happen
106110
# saving the checkpointing at the end of the training
107111
return
108-
os.mkdir(hf_converted_output_dir)
112+
113+
os.makedirs(hf_converted_output_dir, exist_ok=True)
109114
try:
110115
recover_safetensors_from_dcp(
111116
checkpoint_dir,
@@ -165,8 +170,10 @@ def checkpoint(checkpoint_dir, save_dir):
165170
and state.global_step == state.max_steps
166171
):
167172
if not os.path.exists(self.save_model_dir):
168-
os.mkdir(self.save_model_dir)
169-
checkpoint(checkpoint_dir, self.save_model_dir)
173+
os.makedirs(self.save_model_dir, exist_ok=True)
174+
checkpoint(
175+
checkpoint_dir, self.save_model_dir, is_intermediate=False
176+
)
170177

171178
callbacks.append(
172179
ConvertAndSaveHFCheckpointAtEverySave(

0 commit comments

Comments
 (0)