diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 62dd49c26cc71..0de9e8323832c 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -212,20 +212,22 @@ def _attach_model_callbacks(self) -> None: @staticmethod def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: - """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks - to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as - the order of all other callbacks. + """Reorders a list of callbacks such that: + + 1. All `tuner-specific` callbacks appear at the beginning. + 2. `ProgressBar` followed by `ModelCheckpoint` callbacks appear at the end. + 3. All other callbacks maintain their relative order. Args: - callbacks: A list of callbacks. + callbacks (list[Callback]): The list of callbacks to reorder. Return: - A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints - if there were any present in the input. + list[Callback]: A new list with callbacks reordered as described above. """ tuner_callbacks: list[Callback] = [] other_callbacks: list[Callback] = [] + progress_bar_callbacks: list[Callback] = [] checkpoint_callbacks: list[Callback] = [] for cb in callbacks: @@ -233,10 +235,12 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: tuner_callbacks.append(cb) elif isinstance(cb, Checkpoint): checkpoint_callbacks.append(cb) + elif isinstance(cb, ProgressBar): + progress_bar_callbacks.append(cb) else: other_callbacks.append(cb) - return tuner_callbacks + other_callbacks + checkpoint_callbacks + return tuner_callbacks + other_callbacks + progress_bar_callbacks + checkpoint_callbacks def _validate_callbacks_list(callbacks: list[Callback]) -> None: diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 34749087bfb97..f6f7a8cd30838 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -83,7 +83,9 @@ def configure_callbacks(self): return model_callback_mock model = TestModel() - trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, enable_checkpointing=False) + trainer = Trainer( + default_root_dir=tmp_path, fast_dev_run=True, enable_checkpointing=False, enable_progress_bar=False + ) callbacks_before_fit = trainer.callbacks.copy() assert callbacks_before_fit diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index bb8c365bb684c..779e3c12999cf 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -36,8 +36,8 @@ @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) -def test_checkpoint_callbacks_are_last(tmp_path): - """Test that checkpoint callbacks always come last.""" +def test_progressbar_and_checkpoint_callbacks_are_last(tmp_path): + """Test that progress bar and checkpoint callbacks always come last.""" checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1") checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2") early_stopping = EarlyStopping(monitor="foo") @@ -48,9 +48,9 @@ def test_checkpoint_callbacks_are_last(tmp_path): # no model reference trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) assert trainer.callbacks == [ - progress_bar, lr_monitor, model_summary, + progress_bar, checkpoint1, checkpoint2, ] @@ -62,9 +62,9 @@ def test_checkpoint_callbacks_are_last(tmp_path): cb_connector = _CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [ - progress_bar, lr_monitor, model_summary, + progress_bar, checkpoint1, checkpoint2, ] @@ -77,10 +77,10 @@ def test_checkpoint_callbacks_are_last(tmp_path): cb_connector = _CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [ - progress_bar, lr_monitor, early_stopping, model_summary, + progress_bar, checkpoint1, checkpoint2, ] @@ -95,10 +95,10 @@ def test_checkpoint_callbacks_are_last(tmp_path): cb_connector._attach_model_callbacks() assert trainer.callbacks == [ batch_size_finder, - progress_bar, lr_monitor, early_stopping, model_summary, + progress_bar, checkpoint2, checkpoint1, ] @@ -200,7 +200,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")], model_callbacks=[early_stopping1], ) - assert trainer.callbacks == [progress_bar, early_stopping1] + assert trainer.callbacks == [early_stopping1, progress_bar] # progress_bar should be last # multiple callbacks of the same type in trainer trainer = _attach_callbacks( @@ -225,7 +225,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): ], model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2], ) - assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2] + assert trainer.callbacks == [early_stopping1, lr_monitor, grad_accumulation, early_stopping2, progress_bar] class CustomProgressBar(TQDMProgressBar): ...