Commit fa5164c: fix tests
Parent: 363a49e

File tree
2 files changed: +11 / -9 lines


tests/tests_pytorch/callbacks/test_callbacks.py

Lines changed: 3 additions & 1 deletion
@@ -83,7 +83,9 @@ def configure_callbacks(self):
             return model_callback_mock
 
     model = TestModel()
-    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, enable_checkpointing=False)
+    trainer = Trainer(
+        default_root_dir=tmp_path, fast_dev_run=True, enable_checkpointing=False, enable_progress_bar=False
+    )
 
     callbacks_before_fit = trainer.callbacks.copy()
     assert callbacks_before_fit
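
Note: the added enable_progress_bar=False keeps the Trainer from auto-attaching a progress bar callback, so trainer.callbacks stays limited to what the test sets up explicitly. A minimal sketch of that effect (illustrative only, not part of the commit; assumes a standard lightning.pytorch install):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

# With enable_progress_bar=False, no progress bar callback is auto-attached,
# so assertions about the exact contents of trainer.callbacks are not
# affected by the new progress bar ordering.
trainer = Trainer(fast_dev_run=True, enable_checkpointing=False, enable_progress_bar=False, logger=False)
assert not any(isinstance(cb, TQDMProgressBar) for cb in trainer.callbacks)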

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 8 additions & 8 deletions
@@ -36,8 +36,8 @@
 
 
 @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
-def test_checkpoint_callbacks_are_last(tmp_path):
-    """Test that checkpoint callbacks always come last."""
+def test_progressbar_and_checkpoint_callbacks_are_last(tmp_path):
+    """Test that progress bar and checkpoint callbacks always come last."""
     checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1")
     checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2")
     early_stopping = EarlyStopping(monitor="foo")
@@ -48,9 +48,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
     # no model reference
     trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
     assert trainer.callbacks == [
-        progress_bar,
         lr_monitor,
         model_summary,
+        progress_bar,
         checkpoint1,
         checkpoint2,
     ]
@@ -62,9 +62,9 @@ def test_checkpoint_callbacks_are_last(tmp_path):
     cb_connector = _CallbackConnector(trainer)
     cb_connector._attach_model_callbacks()
     assert trainer.callbacks == [
-        progress_bar,
         lr_monitor,
         model_summary,
+        progress_bar,
         checkpoint1,
         checkpoint2,
     ]
@@ -77,10 +77,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
     cb_connector = _CallbackConnector(trainer)
     cb_connector._attach_model_callbacks()
     assert trainer.callbacks == [
-        progress_bar,
         lr_monitor,
         early_stopping,
         model_summary,
+        progress_bar,
         checkpoint1,
         checkpoint2,
     ]
@@ -95,10 +95,10 @@ def test_checkpoint_callbacks_are_last(tmp_path):
     cb_connector._attach_model_callbacks()
     assert trainer.callbacks == [
         batch_size_finder,
-        progress_bar,
         lr_monitor,
         early_stopping,
         model_summary,
+        progress_bar,
         checkpoint2,
         checkpoint1,
     ]
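
All of the expected orderings above follow the same rule: regular callbacks keep their relative order, the progress bar moves to the end of that group, and checkpoint callbacks come last. A minimal sketch of such a reordering step, assuming the ProgressBar and Checkpoint base classes from lightning.pytorch.callbacks (the real logic lives in _CallbackConnector; this is not the actual implementation):

from lightning.pytorch.callbacks import Checkpoint, ProgressBar

def reorder_callbacks(callbacks):
    # Preserve relative order within each group: regular callbacks first,
    # then progress bars, then checkpoints (mirroring the assertions above).
    regular = [cb for cb in callbacks if not isinstance(cb, (ProgressBar, Checkpoint))]
    bars = [cb for cb in callbacks if isinstance(cb, ProgressBar)]
    checkpoints = [cb for cb in callbacks if isinstance(cb, Checkpoint)]
    return regular + bars + checkpoints
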
@@ -200,7 +200,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")],
         model_callbacks=[early_stopping1],
     )
-    assert trainer.callbacks == [progress_bar, early_stopping1]
+    assert trainer.callbacks == [early_stopping1, progress_bar]  # progress_bar should be last
 
     # multiple callbacks of the same type in trainer
     trainer = _attach_callbacks(
@@ -225,7 +225,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         ],
         model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2],
     )
-    assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2]
+    assert trainer.callbacks == [early_stopping1, lr_monitor, grad_accumulation, early_stopping2, progress_bar]
 
     class CustomProgressBar(TQDMProgressBar): ...
 