Commit 93bda6b

fix: directly save final ckpt in save_model_dir
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 3344193 commit 93bda6b
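
The practical effect of the change: before this commit, the conversion helper always created an hf_converted_checkpoint subfolder inside whatever directory it was handed, so the end-of-training model landed under save_model_dir/hf_converted_checkpoint; with this commit, intermediate step checkpoints keep that subfolder while the final model is written directly into save_model_dir. A minimal sketch of the resulting paths, using hypothetical directory names and step count:

import os

# Hypothetical example values, for illustration only.
output_dir = "outputs"          # args.output_dir
save_model_dir = "final_model"  # user-provided save_model_dir
step = 100                      # state.global_step

# Intermediate checkpoint: converted weights still go into a subfolder
# of the step checkpoint directory.
intermediate = os.path.join(
    output_dir, f"checkpoint-{step}", "hf_converted_checkpoint"
)

# Final checkpoint: previously nested one level deeper, now written
# directly into save_model_dir.
final_before = os.path.join(save_model_dir, "hf_converted_checkpoint")
final_after = save_model_dir

print(intermediate)  # e.g. outputs/checkpoint-100/hf_converted_checkpoint
print(final_before)  # e.g. final_model/hf_converted_checkpoint
print(final_after)   # e.g. final_model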

File tree

1 file changed (+14, -17 lines)

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 14 additions & 17 deletions
@@ -97,30 +97,25 @@ def on_save(
             """
 
             def checkpoint(checkpoint_dir, save_dir):
-                hf_converted_output_dir = os.path.join(
-                    save_dir, "hf_converted_checkpoint"
-                )
-                if os.path.exists(hf_converted_output_dir):
+                if os.path.exists(save_dir):
                     # If the folder already exists
                     # we return, since this is possible to happen
                     # saving the checkpointing at the end of the training
                     return
-                os.mkdir(hf_converted_output_dir)
+
                 try:
                     recover_safetensors_from_dcp(
                         checkpoint_dir,
                         self.pretrained_model_name_or_path,
-                        hf_converted_output_dir,
+                        save_dir,
                     )
                     # Save tokenizer
                     if self.trainer.processing_class:
-                        self.trainer.processing_class.save_pretrained(
-                            hf_converted_output_dir
-                        )
+                        self.trainer.processing_class.save_pretrained(save_dir)
                     # Save training args
                     torch.save(
                         args,
-                        os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
+                        os.path.join(save_dir, TRAINING_ARGS_NAME),
                     )
 
                     # Unwrap FSDP module
@@ -135,16 +130,14 @@ def checkpoint(checkpoint_dir, save_dir):
                             list(config_dict["target_modules"])
                         )
                         with open(
-                            os.path.join(
-                                hf_converted_output_dir, "adapter_config.json"
-                            ),
+                            os.path.join(save_dir, "adapter_config.json"),
                             "w",
                             encoding="utf-8",
                         ) as f:
                             json.dump(config_dict, f, indent=2)
 
                     else:
-                        model.config.save_pretrained(hf_converted_output_dir)
+                        model.config.save_pretrained(save_dir)
 
                 except Exception as e:
                     raise ValueError(
@@ -157,15 +150,19 @@ def checkpoint(checkpoint_dir, save_dir):
             checkpoint_dir = os.path.join(
                 args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
             )
-            checkpoint(checkpoint_dir, checkpoint_dir)
+            hf_converted_path = os.path.join(
+                checkpoint_dir, "hf_converted_checkpoint"
+            )
+            if not os.path.exists(hf_converted_path):
+                os.makedirs(hf_converted_path)
+            checkpoint(checkpoint_dir, hf_converted_path)
 
             # If final save directory is provided, save the model there
             if (
                 getattr(self, "save_model_dir", None)
                 and state.global_step == state.max_steps
             ):
-                if not os.path.exists(self.save_model_dir):
-                    os.mkdir(self.save_model_dir)
+                os.makedirs(self.save_model_dir, exist_ok=True)
                 checkpoint(checkpoint_dir, self.save_model_dir)
 
     callbacks.append(
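
A related idiom change in the last hunk: the old code guarded os.mkdir with an existence check, while the new code uses os.makedirs(..., exist_ok=True), which also creates any missing parent directories and needs no separate check. A short standalone sketch of the two patterns, using a hypothetical nested path:

import os

# Hypothetical nested path; the parent ("runs") may not exist yet.
save_model_dir = "runs/final_model"

# Old pattern: check, then create a single directory level with os.mkdir.
# os.mkdir raises FileNotFoundError when the parent directory is missing.
try:
    if not os.path.exists(save_model_dir):
        os.mkdir(save_model_dir)
except FileNotFoundError:
    print("os.mkdir cannot create missing parent directories")

# New pattern from the diff: creates intermediate directories as needed and
# is a no-op if the target already exists, so no existence check is required.
os.makedirs(save_model_dir, exist_ok=True)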
