Skip to content

Commit 22720d1

Browse files
authored
✨ Add logging for training completion and model saving in training scripts (#4048)
1 parent c8a5add commit 22720d1

File tree

5 files changed

+30
-0
lines changed

5 files changed

+30
-0
lines changed

trl/scripts/dpo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,21 @@ def main(script_args, training_args, model_args, dataset_args):
159159
# Train the model
160160
trainer.train()
161161

162+
# Log training complete
163+
trainer.accelerator.print("✅ Training completed.")
164+
162165
if training_args.eval_strategy != "no":
163166
metrics = trainer.evaluate()
164167
trainer.log_metrics("eval", metrics)
165168
trainer.save_metrics("eval", metrics)
166169

167170
# Save and push to Hub
168171
trainer.save_model(training_args.output_dir)
172+
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
173+
169174
if training_args.push_to_hub:
170175
trainer.push_to_hub(dataset_name=script_args.dataset_name)
176+
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
171177

172178

173179
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):

trl/scripts/grpo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,16 @@ def main(script_args, training_args, model_args, dataset_args):
141141
# Train the model
142142
trainer.train()
143143

144+
# Log training complete
145+
trainer.accelerator.print("✅ Training completed.")
146+
144147
# Save and push to Hub
145148
trainer.save_model(training_args.output_dir)
149+
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
150+
146151
if training_args.push_to_hub:
147152
trainer.push_to_hub(dataset_name=script_args.dataset_name)
153+
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
148154

149155

150156
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):

trl/scripts/kto.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,16 @@ def main(script_args, training_args, model_args, dataset_args):
135135
# Train the model
136136
trainer.train()
137137

138+
# Log training complete
139+
trainer.accelerator.print("✅ Training completed.")
140+
138141
# Save and push to Hub
139142
trainer.save_model(training_args.output_dir)
143+
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
144+
140145
if training_args.push_to_hub:
141146
trainer.push_to_hub(dataset_name=script_args.dataset_name)
147+
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
142148

143149

144150
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):

trl/scripts/rloo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,16 @@ def main(script_args, training_args, model_args, dataset_args):
141141
# Train the model
142142
trainer.train()
143143

144+
# Log training complete
145+
trainer.accelerator.print("✅ Training completed.")
146+
144147
# Save and push to Hub
145148
trainer.save_model(training_args.output_dir)
149+
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
150+
146151
if training_args.push_to_hub:
147152
trainer.push_to_hub(dataset_name=script_args.dataset_name)
153+
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
148154

149155

150156
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):

trl/scripts/sft.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,16 @@ def main(script_args, training_args, model_args, dataset_args):
152152
# Train the model
153153
trainer.train()
154154

155+
# Log training complete
156+
trainer.accelerator.print("✅ Training completed.")
157+
155158
# Save and push to Hub
156159
trainer.save_model(training_args.output_dir)
160+
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
161+
157162
if training_args.push_to_hub:
158163
trainer.push_to_hub(dataset_name=script_args.dataset_name)
164+
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
159165

160166

161167
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):

0 commit comments

Comments
 (0)