Skip to content

Commit d1a2c3a

Browse files
committed
Reorder callback list to prioritize tuner-specific callbacks and maintain order for ProgressBar and ModelCheckpoint
1 parent 918a1a6 commit d1a2c3a

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,31 +212,35 @@ def _attach_model_callbacks(self) -> None:
212212

213213
@staticmethod
214214
def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]:
215-
"""Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks
216-
to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as
217-
the order of all other callbacks.
215+
"""Reorders a list of callbacks such that:
216+
217+
1. All `tuner-specific` callbacks appear at the beginning.
218+
2. `ProgressBar` followed by `ModelCheckpoint` callbacks appear at the end.
219+
3. All other callbacks maintain their relative order.
218220
219221
Args:
220-
callbacks: A list of callbacks.
222+
callbacks (list[Callback]): The list of callbacks to reorder.
221223
222224
Return:
223-
A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints
224-
if there were any present in the input.
225+
list[Callback]: A new list with callbacks reordered as described above.
225226
226227
"""
227228
tuner_callbacks: list[Callback] = []
228229
other_callbacks: list[Callback] = []
230+
progress_bar_callbacks: list[Callback] = []
229231
checkpoint_callbacks: list[Callback] = []
230232

231233
for cb in callbacks:
232234
if isinstance(cb, (BatchSizeFinder, LearningRateFinder)):
233235
tuner_callbacks.append(cb)
234236
elif isinstance(cb, Checkpoint):
235237
checkpoint_callbacks.append(cb)
238+
elif isinstance(cb, ProgressBar):
239+
progress_bar_callbacks.append(cb)
236240
else:
237241
other_callbacks.append(cb)
238242

239-
return tuner_callbacks + other_callbacks + checkpoint_callbacks
243+
return tuner_callbacks + other_callbacks + progress_bar_callbacks + checkpoint_callbacks
240244

241245

242246
def _validate_callbacks_list(callbacks: list[Callback]) -> None:

0 commit comments

Comments
 (0)