diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index d50c16a58..c093a4d81 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -246,15 +246,6 @@ def train(
             trainer_callbacks.append(cb)
             trackers.append(t)
 
-    # Now add trainer controller callbacks if requested
-    if (trainer_controller_args is not None) and (
-        trainer_controller_args.trainer_controller_config_file is not None
-    ):
-        tc_callback = TrainerControllerCallback(
-            trainer_controller_args.trainer_controller_config_file,
-        )
-        trainer_callbacks.append(tc_callback)
-
     # Add any extra callback if passed by users
     if additional_callbacks is not None:
         for cb in additional_callbacks:
@@ -500,6 +491,8 @@ def train(
             model
         )
 
+    # Register fms-acceleration callbacks before Trainer Controller
+    # so that on_save() runs earlier for proper model unwrapping and checkpoint handling
     if framework is not None:
         accelerator = None if not is_accelerate_available() else trainer.accelerator
@@ -514,6 +507,14 @@ def train(
         ):
             trainer.add_callback(clb)
 
+    if (trainer_controller_args is not None) and (
+        trainer_controller_args.trainer_controller_config_file is not None
+    ):
+        tc_callback = TrainerControllerCallback(
+            trainer_controller_args.trainer_controller_config_file,
+        )
+        trainer.add_callback(tc_callback)
+
     trainer.train(resume_from_checkpoint)
     additional_metadata = {}
     additional_metadata["added_tokens_info"] = added_tokens_dict