@@ -246,15 +246,6 @@ def train(
246246 trainer_callbacks .append (cb )
247247 trackers .append (t )
248248
249- # Now add trainer controller callbacks if requested
250- if (trainer_controller_args is not None ) and (
251- trainer_controller_args .trainer_controller_config_file is not None
252- ):
253- tc_callback = TrainerControllerCallback (
254- trainer_controller_args .trainer_controller_config_file ,
255- )
256- trainer_callbacks .append (tc_callback )
257-
258249 # Add any extra callback if passed by users
259250 if additional_callbacks is not None :
260251 for cb in additional_callbacks :
@@ -500,6 +491,8 @@ def train(
500491 model
501492 )
502493
494+ # Register fms-acceleration callbacks before Trainer Controller
495+ # so that on_save() runs earlier for proper model unwrapping and checkpoint handling
503496 if framework is not None :
504497 accelerator = None if not is_accelerate_available () else trainer .accelerator
505498
@@ -514,6 +507,14 @@ def train(
514507 ):
515508 trainer .add_callback (clb )
516509
510+ if (trainer_controller_args is not None ) and (
511+ trainer_controller_args .trainer_controller_config_file is not None
512+ ):
513+ tc_callback = TrainerControllerCallback (
514+ trainer_controller_args .trainer_controller_config_file ,
515+ )
516+ trainer .add_callback (tc_callback )
517+
517518 trainer .train (resume_from_checkpoint )
518519 additional_metadata = {}
519520 additional_metadata ["added_tokens_info" ] = added_tokens_dict
0 commit comments