Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/lightning/pytorch/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,31 +212,35 @@ def _attach_model_callbacks(self) -> None:

@staticmethod
def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]:
    """Reorders a list of callbacks such that:

    1. All `tuner-specific` callbacks appear at the beginning.
    2. `ProgressBar` followed by `ModelCheckpoint` callbacks appear at the end.
    3. All other callbacks maintain their relative order.

    Args:
        callbacks (list[Callback]): The list of callbacks to reorder.

    Return:
        list[Callback]: A new list with callbacks reordered as described above.

    """

    def _rank(cb: Callback) -> int:
        # Rank determines final position: tuner callbacks first, then everything
        # else, then the progress bar, then checkpoints. The Checkpoint check is
        # performed before the ProgressBar check so a callback subclassing both
        # is bucketed as a checkpoint, mirroring the original if/elif ordering.
        if isinstance(cb, (BatchSizeFinder, LearningRateFinder)):
            return 0
        if isinstance(cb, Checkpoint):
            return 3
        if isinstance(cb, ProgressBar):
            return 2
        return 1

    # `sorted` is guaranteed stable, so callbacks with equal rank keep their
    # original relative order — exactly the grouping the docstring promises.
    return sorted(callbacks, key=_rank)


def _validate_callbacks_list(callbacks: list[Callback]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion tests/tests_pytorch/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def configure_callbacks(self):
return model_callback_mock

model = TestModel()
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, enable_checkpointing=False)
trainer = Trainer(
default_root_dir=tmp_path, fast_dev_run=True, enable_checkpointing=False, enable_progress_bar=False
)

callbacks_before_fit = trainer.callbacks.copy()
assert callbacks_before_fit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_checkpoint_callbacks_are_last(tmp_path):
"""Test that checkpoint callbacks always come last."""
def test_progressbar_and_checkpoint_callbacks_are_last(tmp_path):
"""Test that progress bar and checkpoint callbacks always come last."""
checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1")
checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2")
early_stopping = EarlyStopping(monitor="foo")
Expand All @@ -48,9 +48,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
# no model reference
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
assert trainer.callbacks == [
progress_bar,
lr_monitor,
model_summary,
progress_bar,
checkpoint1,
checkpoint2,
]
Expand All @@ -62,9 +62,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
cb_connector = _CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [
progress_bar,
lr_monitor,
model_summary,
progress_bar,
checkpoint1,
checkpoint2,
]
Expand All @@ -77,10 +77,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
cb_connector = _CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [
progress_bar,
lr_monitor,
early_stopping,
model_summary,
progress_bar,
checkpoint1,
checkpoint2,
]
Expand All @@ -95,10 +95,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [
batch_size_finder,
progress_bar,
lr_monitor,
early_stopping,
model_summary,
progress_bar,
checkpoint2,
checkpoint1,
]
Expand Down Expand Up @@ -200,7 +200,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")],
model_callbacks=[early_stopping1],
)
assert trainer.callbacks == [progress_bar, early_stopping1]
assert trainer.callbacks == [early_stopping1, progress_bar] # progress_bar should be last

# multiple callbacks of the same type in trainer
trainer = _attach_callbacks(
Expand All @@ -225,7 +225,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
],
model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2],
)
assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2]
assert trainer.callbacks == [early_stopping1, lr_monitor, grad_accumulation, early_stopping2, progress_bar]

class CustomProgressBar(TQDMProgressBar): ...

Expand Down
Loading