Skip to content

Commit cb61502

Browse files
committed
#4048 and #4124
1 parent 1c4b295 commit cb61502

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

trl/scripts/reward.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def main(script_args, training_args, model_args, dataset_args):
5353
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
5454
"dataset and `dataset_name` will be ignored."
5555
)
56+
dataset = get_dataset(dataset_args)
5657
elif dataset_args.datasets and not script_args.dataset_name:
5758
dataset = get_dataset(dataset_args)
5859
elif not dataset_args.datasets and script_args.dataset_name:
@@ -74,10 +75,16 @@ def main(script_args, training_args, model_args, dataset_args):
7475
# Train the model
7576
trainer.train()
7677

78+
# Log training complete
79+
trainer.accelerator.print("✅ Training completed.")
80+
7781
# Save and push to Hub
7882
trainer.save_model(training_args.output_dir)
83+
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
84+
7985
if training_args.push_to_hub:
8086
trainer.push_to_hub(dataset_name=script_args.dataset_name)
87+
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
8188

8289

8390
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):

0 commit comments

Comments
 (0)