Skip to content

Commit ebf5743

Browse files
feat: ensure fms-acceleration callbacks run before TrainerController (#620)
Signed-off-by: yashasvi <[email protected]>
1 parent: 5187516 · commit: ebf5743

File tree

1 file changed: 10 additions (+10) and 9 deletions (-9).

tuning/sft_trainer.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -246,15 +246,6 @@ def train(
246246
trainer_callbacks.append(cb)
247247
trackers.append(t)
248248

249-
# Now add trainer controller callbacks if requested
250-
if (trainer_controller_args is not None) and (
251-
trainer_controller_args.trainer_controller_config_file is not None
252-
):
253-
tc_callback = TrainerControllerCallback(
254-
trainer_controller_args.trainer_controller_config_file,
255-
)
256-
trainer_callbacks.append(tc_callback)
257-
258249
# Add any extra callback if passed by users
259250
if additional_callbacks is not None:
260251
for cb in additional_callbacks:
@@ -500,6 +491,8 @@ def train(
500491
model
501492
)
502493

494+
# Register fms-acceleration callbacks before Trainer Controller
495+
# so that on_save() runs earlier for proper model unwrapping and checkpoint handling
503496
if framework is not None:
504497
accelerator = None if not is_accelerate_available() else trainer.accelerator
505498

@@ -514,6 +507,14 @@ def train(
514507
):
515508
trainer.add_callback(clb)
516509

510+
if (trainer_controller_args is not None) and (
511+
trainer_controller_args.trainer_controller_config_file is not None
512+
):
513+
tc_callback = TrainerControllerCallback(
514+
trainer_controller_args.trainer_controller_config_file,
515+
)
516+
trainer.add_callback(tc_callback)
517+
517518
trainer.train(resume_from_checkpoint)
518519
additional_metadata = {}
519520
additional_metadata["added_tokens_info"] = added_tokens_dict

0 commit comments

Comments
 (0)