36
36
37
37
38
38
@patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
39
- def test_checkpoint_callbacks_are_last (tmp_path ):
40
- """Test that checkpoint callbacks always come last."""
39
+ def test_progressbar_and_checkpoint_callbacks_are_last (tmp_path ):
40
+ """Test that progress bar and checkpoint callbacks always come last."""
41
41
checkpoint1 = ModelCheckpoint (tmp_path / "path1" , filename = "ckpt1" , monitor = "val_loss_c1" )
42
42
checkpoint2 = ModelCheckpoint (tmp_path / "path2" , filename = "ckpt2" , monitor = "val_loss_c2" )
43
43
early_stopping = EarlyStopping (monitor = "foo" )
@@ -48,9 +48,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
48
48
# no model reference
49
49
trainer = Trainer (callbacks = [checkpoint1 , progress_bar , lr_monitor , model_summary , checkpoint2 ])
50
50
assert trainer .callbacks == [
51
- progress_bar ,
52
51
lr_monitor ,
53
52
model_summary ,
53
+ progress_bar ,
54
54
checkpoint1 ,
55
55
checkpoint2 ,
56
56
]
@@ -62,9 +62,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
62
62
cb_connector = _CallbackConnector (trainer )
63
63
cb_connector ._attach_model_callbacks ()
64
64
assert trainer .callbacks == [
65
- progress_bar ,
66
65
lr_monitor ,
67
66
model_summary ,
67
+ progress_bar ,
68
68
checkpoint1 ,
69
69
checkpoint2 ,
70
70
]
@@ -77,10 +77,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
77
77
cb_connector = _CallbackConnector (trainer )
78
78
cb_connector ._attach_model_callbacks ()
79
79
assert trainer .callbacks == [
80
- progress_bar ,
81
80
lr_monitor ,
82
81
early_stopping ,
83
82
model_summary ,
83
+ progress_bar ,
84
84
checkpoint1 ,
85
85
checkpoint2 ,
86
86
]
@@ -95,10 +95,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
95
95
cb_connector ._attach_model_callbacks ()
96
96
assert trainer .callbacks == [
97
97
batch_size_finder ,
98
- progress_bar ,
99
98
lr_monitor ,
100
99
early_stopping ,
101
100
model_summary ,
101
+ progress_bar ,
102
102
checkpoint2 ,
103
103
checkpoint1 ,
104
104
]
@@ -200,7 +200,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
200
200
trainer_callbacks = [progress_bar , EarlyStopping (monitor = "red" )],
201
201
model_callbacks = [early_stopping1 ],
202
202
)
203
- assert trainer .callbacks == [progress_bar , early_stopping1 ]
203
+ assert trainer .callbacks == [early_stopping1 , progress_bar ] # progress_bar should be last
204
204
205
205
# multiple callbacks of the same type in trainer
206
206
trainer = _attach_callbacks (
@@ -225,7 +225,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
225
225
],
226
226
model_callbacks = [early_stopping1 , lr_monitor , grad_accumulation , early_stopping2 ],
227
227
)
228
- assert trainer .callbacks == [progress_bar , early_stopping1 , lr_monitor , grad_accumulation , early_stopping2 ]
228
+ assert trainer .callbacks == [early_stopping1 , lr_monitor , grad_accumulation , early_stopping2 , progress_bar ]
229
229
230
230
class CustomProgressBar (TQDMProgressBar ): ...
231
231
0 commit comments