Skip to content

Commit dbf113d

Browse files
committed
test: disable rich progress bar in callback tests to ensure tqdm usage
1 parent de27bfe commit dbf113d

File tree

3 files changed

+277
-270
lines changed

3 files changed

+277
-270
lines changed

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 128 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from collections import defaultdict
1919
from typing import Union
2020
from unittest import mock
21-
from unittest.mock import ANY, Mock, PropertyMock, call
21+
from unittest.mock import ANY, Mock, PropertyMock, call, patch
2222

2323
import pytest
2424
import torch
@@ -112,120 +112,122 @@ def test_tqdm_progress_bar_misconfiguration():
112112
@pytest.mark.parametrize("num_dl", [1, 2])
113113
def test_tqdm_progress_bar_totals(tmp_path, num_dl):
114114
"""Test that the progress finishes with the correct total steps processed."""
115-
116-
class CustomModel(BoringModel):
117-
def _get_dataloaders(self):
118-
dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
119-
return dls[0] if num_dl == 1 else dls
120-
121-
def val_dataloader(self):
122-
return self._get_dataloaders()
123-
124-
def test_dataloader(self):
125-
return self._get_dataloaders()
126-
127-
def predict_dataloader(self):
128-
return self._get_dataloaders()
129-
130-
def validation_step(self, batch, batch_idx, dataloader_idx=0):
131-
return
132-
133-
def test_step(self, batch, batch_idx, dataloader_idx=0):
134-
return
135-
136-
def predict_step(self, batch, batch_idx, dataloader_idx=0):
137-
return
138-
139-
model = CustomModel()
140-
141-
# check the sanity dataloaders
142-
num_sanity_val_steps = 4
143-
trainer = Trainer(
144-
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps
145-
)
146-
pbar = trainer.progress_bar_callback
147-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
148-
trainer.fit(model)
149-
150-
expected_sanity_steps = [num_sanity_val_steps] * num_dl
151-
assert not pbar.val_progress_bar.leave
152-
assert trainer.num_sanity_val_batches == expected_sanity_steps
153-
assert pbar.val_progress_bar.total_values == expected_sanity_steps
154-
assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
155-
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
156-
157-
# fit
158-
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
159-
pbar = trainer.progress_bar_callback
160-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
161-
trainer.fit(model)
162-
163-
n = trainer.num_training_batches
164-
m = trainer.num_val_batches
165-
assert len(trainer.train_dataloader) == n
166-
# train progress bar should have reached the end
167-
assert pbar.train_progress_bar.total == n
168-
assert pbar.train_progress_bar.n == n
169-
assert pbar.train_progress_bar.leave
170-
171-
# check val progress bar total
172-
assert pbar.val_progress_bar.total_values == m
173-
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
174-
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
175-
assert not pbar.val_progress_bar.leave
176-
177-
# validate
178-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
179-
trainer.validate(model)
180-
assert trainer.num_val_batches == m
181-
assert pbar.val_progress_bar.total_values == m
182-
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
183-
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
184-
185-
# test
186-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
187-
trainer.test(model)
188-
assert pbar.test_progress_bar.leave
189-
k = trainer.num_test_batches
190-
assert pbar.test_progress_bar.total_values == k
191-
assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
192-
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
193-
assert pbar.test_progress_bar.leave
194-
195-
# predict
196-
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
197-
trainer.predict(model)
198-
assert pbar.predict_progress_bar.leave
199-
k = trainer.num_predict_batches
200-
assert pbar.predict_progress_bar.total_values == k
201-
assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
202-
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
115+
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
116+
117+
class CustomModel(BoringModel):
118+
def _get_dataloaders(self):
119+
dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
120+
return dls[0] if num_dl == 1 else dls
121+
122+
def val_dataloader(self):
123+
return self._get_dataloaders()
124+
125+
def test_dataloader(self):
126+
return self._get_dataloaders()
127+
128+
def predict_dataloader(self):
129+
return self._get_dataloaders()
130+
131+
def validation_step(self, batch, batch_idx, dataloader_idx=0):
132+
return
133+
134+
def test_step(self, batch, batch_idx, dataloader_idx=0):
135+
return
136+
137+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
138+
return
139+
140+
model = CustomModel()
141+
142+
# check the sanity dataloaders
143+
num_sanity_val_steps = 4
144+
trainer = Trainer(
145+
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps
146+
)
147+
pbar = trainer.progress_bar_callback
148+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
149+
trainer.fit(model)
150+
151+
expected_sanity_steps = [num_sanity_val_steps] * num_dl
152+
assert not pbar.val_progress_bar.leave
153+
assert trainer.num_sanity_val_batches == expected_sanity_steps
154+
assert pbar.val_progress_bar.total_values == expected_sanity_steps
155+
assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
156+
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
157+
158+
# fit
159+
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
160+
pbar = trainer.progress_bar_callback
161+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
162+
trainer.fit(model)
163+
164+
n = trainer.num_training_batches
165+
m = trainer.num_val_batches
166+
assert len(trainer.train_dataloader) == n
167+
# train progress bar should have reached the end
168+
assert pbar.train_progress_bar.total == n
169+
assert pbar.train_progress_bar.n == n
170+
assert pbar.train_progress_bar.leave
171+
172+
# check val progress bar total
173+
assert pbar.val_progress_bar.total_values == m
174+
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
175+
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
176+
assert not pbar.val_progress_bar.leave
177+
178+
# validate
179+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
180+
trainer.validate(model)
181+
assert trainer.num_val_batches == m
182+
assert pbar.val_progress_bar.total_values == m
183+
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
184+
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
185+
186+
# test
187+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
188+
trainer.test(model)
189+
assert pbar.test_progress_bar.leave
190+
k = trainer.num_test_batches
191+
assert pbar.test_progress_bar.total_values == k
192+
assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
193+
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
194+
assert pbar.test_progress_bar.leave
195+
196+
# predict
197+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
198+
trainer.predict(model)
199+
assert pbar.predict_progress_bar.leave
200+
k = trainer.num_predict_batches
201+
assert pbar.predict_progress_bar.total_values == k
202+
assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
203+
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
203204
assert pbar.predict_progress_bar.leave
204205

205206

206207
def test_tqdm_progress_bar_fast_dev_run(tmp_path):
207-
model = BoringModel()
208+
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
209+
model = BoringModel()
208210

209-
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)
211+
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)
210212

211-
trainer.fit(model)
213+
trainer.fit(model)
212214

213-
pbar = trainer.progress_bar_callback
215+
pbar = trainer.progress_bar_callback
214216

215-
assert pbar.val_progress_bar.n == 1
216-
assert pbar.val_progress_bar.total == 1
217+
assert pbar.val_progress_bar.n == 1
218+
assert pbar.val_progress_bar.total == 1
217219

218-
# the train progress bar should display 1 batch
219-
assert pbar.train_progress_bar.total == 1
220-
assert pbar.train_progress_bar.n == 1
220+
# the train progress bar should display 1 batch
221+
assert pbar.train_progress_bar.total == 1
222+
assert pbar.train_progress_bar.n == 1
221223

222-
trainer.validate(model)
224+
trainer.validate(model)
223225

224-
# the validation progress bar should display 1 batch
225-
assert pbar.val_progress_bar.total == 1
226-
assert pbar.val_progress_bar.n == 1
226+
# the validation progress bar should display 1 batch
227+
assert pbar.val_progress_bar.total == 1
228+
assert pbar.val_progress_bar.n == 1
227229

228-
trainer.test(model)
230+
trainer.test(model)
229231

230232
# the test progress bar should display 1 batch
231233
assert pbar.test_progress_bar.total == 1
@@ -325,14 +327,15 @@ def test_tqdm_progress_bar_default_value(tmp_path):
325327
@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
326328
def test_tqdm_progress_bar_value_on_colab(tmp_path):
327329
"""Test that Trainer will override the default in Google COLAB."""
328-
trainer = Trainer(default_root_dir=tmp_path)
329-
assert trainer.progress_bar_callback.refresh_rate == 20
330+
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
331+
trainer = Trainer(default_root_dir=tmp_path)
332+
assert trainer.progress_bar_callback.refresh_rate == 20
330333

331-
trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar())
332-
assert trainer.progress_bar_callback.refresh_rate == 20
334+
trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar())
335+
assert trainer.progress_bar_callback.refresh_rate == 20
333336

334-
trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19))
335-
assert trainer.progress_bar_callback.refresh_rate == 19
337+
trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19))
338+
assert trainer.progress_bar_callback.refresh_rate == 19
336339

337340

338341
@pytest.mark.parametrize(
@@ -413,21 +416,22 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra
413416

414417
def test_tensor_to_float_conversion(tmp_path):
415418
"""Check tensor gets converted to float."""
416-
417-
class TestModel(BoringModel):
418-
def training_step(self, batch, batch_idx):
419-
self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False)
420-
self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False)
421-
self.log("c", 2, prog_bar=True, on_epoch=False)
422-
return super().training_step(batch, batch_idx)
423-
424-
trainer = Trainer(
425-
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
426-
)
427-
trainer.fit(TestModel())
428-
429-
torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
430-
assert trainer.progress_bar_metrics["b"] == 1.0
419+
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
420+
421+
class TestModel(BoringModel):
422+
def training_step(self, batch, batch_idx):
423+
self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False)
424+
self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False)
425+
self.log("c", 2, prog_bar=True, on_epoch=False)
426+
return super().training_step(batch, batch_idx)
427+
428+
trainer = Trainer(
429+
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
430+
)
431+
trainer.fit(TestModel())
432+
433+
torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
434+
assert trainer.progress_bar_metrics["b"] == 1.0
431435
assert trainer.progress_bar_metrics["c"] == 2.0
432436
pbar = trainer.progress_bar_callback.train_progress_bar
433437
actual = str(pbar.postfix)

tests/tests_pytorch/callbacks/test_callbacks.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from pathlib import Path
1515
from re import escape
16-
from unittest.mock import Mock
16+
from unittest.mock import Mock, patch
1717

1818
import pytest
1919
from lightning_utilities.test.warning import no_warning_call
@@ -122,17 +122,18 @@ def load_state_dict(self, state_dict) -> None:
122122
def test_resume_callback_state_saved_by_type_stateful(tmp_path):
123123
"""Test that a legacy checkpoint that didn't use a state key before can still be loaded, using
124124
state_dict/load_state_dict."""
125-
model = BoringModel()
126-
callback = OldStatefulCallback(state=111)
127-
trainer = Trainer(default_root_dir=tmp_path, max_steps=1, callbacks=[callback])
128-
trainer.fit(model)
129-
ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
130-
assert ckpt_path.exists()
131-
132-
callback = OldStatefulCallback(state=222)
133-
trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback])
134-
trainer.fit(model, ckpt_path=ckpt_path)
135-
assert callback.state == 111
125+
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
126+
model = BoringModel()
127+
callback = OldStatefulCallback(state=111)
128+
trainer = Trainer(default_root_dir=tmp_path, max_steps=1, callbacks=[callback])
129+
trainer.fit(model)
130+
ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
131+
assert ckpt_path.exists()
132+
133+
callback = OldStatefulCallback(state=222)
134+
trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback])
135+
trainer.fit(model, ckpt_path=ckpt_path)
136+
assert callback.state == 111
136137

137138

138139
def test_resume_incomplete_callbacks_list_warning(tmp_path):

0 commit comments

Comments
 (0)