Commit f8081b0

refactor: move patch decorators outside test functions in callback tests
1 parent 04b8870 commit f8081b0

File tree

3 files changed: +168 / -168 lines changed
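The change in each file is mechanical: a patch that used to be applied as a with-statement inside the test body is now applied as a @patch decorator on the test function, and the body loses one level of indentation. The following is a minimal, self-contained sketch of that pattern; the Config.RICH_AVAILABLE flag is a hypothetical stand-in for the real patch target, lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE.

from unittest.mock import patch


class Config:
    # Hypothetical flag standing in for the real patch target,
    # lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE.
    RICH_AVAILABLE = True


def test_flag_with_context_manager():
    # Before: the patch is scoped to the with-block, so the whole body is indented under it.
    with patch.object(Config, "RICH_AVAILABLE", False):
        assert Config.RICH_AVAILABLE is False


@patch.object(Config, "RICH_AVAILABLE", False)
def test_flag_with_decorator():
    # After: the decorator applies the patch for the duration of the test and the body stays flat.
    assert Config.RICH_AVAILABLE is False

Because an explicit replacement value (False) is supplied, patch does not inject a mock argument into the decorated function, which is why the test signatures in the diffs below are unchanged.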

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 22 additions & 22 deletions
@@ -204,30 +204,30 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
     assert pbar.predict_progress_bar.leave


+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_tqdm_progress_bar_fast_dev_run(tmp_path):
-    with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
-        model = BoringModel()
+    model = BoringModel()

-        trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)
+    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)

-        trainer.fit(model)
+    trainer.fit(model)

-        pbar = trainer.progress_bar_callback
+    pbar = trainer.progress_bar_callback

-        assert pbar.val_progress_bar.n == 1
-        assert pbar.val_progress_bar.total == 1
+    assert pbar.val_progress_bar.n == 1
+    assert pbar.val_progress_bar.total == 1

-        # the train progress bar should display 1 batch
-        assert pbar.train_progress_bar.total == 1
-        assert pbar.train_progress_bar.n == 1
+    # the train progress bar should display 1 batch
+    assert pbar.train_progress_bar.total == 1
+    assert pbar.train_progress_bar.n == 1

-        trainer.validate(model)
+    trainer.validate(model)

-        # the validation progress bar should display 1 batch
-        assert pbar.val_progress_bar.total == 1
-        assert pbar.val_progress_bar.n == 1
+    # the validation progress bar should display 1 batch
+    assert pbar.val_progress_bar.total == 1
+    assert pbar.val_progress_bar.n == 1

-        trainer.test(model)
+    trainer.test(model)

     # the test progress bar should display 1 batch
     assert pbar.test_progress_bar.total == 1
@@ -325,17 +325,17 @@ def test_tqdm_progress_bar_default_value(tmp_path):


 @mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_tqdm_progress_bar_value_on_colab(tmp_path):
     """Test that Trainer will override the default in Google COLAB."""
-    with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
-        trainer = Trainer(default_root_dir=tmp_path)
-        assert trainer.progress_bar_callback.refresh_rate == 20
+    trainer = Trainer(default_root_dir=tmp_path)
+    assert trainer.progress_bar_callback.refresh_rate == 20

-        trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar())
-        assert trainer.progress_bar_callback.refresh_rate == 20
+    trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar())
+    assert trainer.progress_bar_callback.refresh_rate == 20

-        trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19))
-        assert trainer.progress_bar_callback.refresh_rate == 19
+    trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19))
+    assert trainer.progress_bar_callback.refresh_rate == 19


 @pytest.mark.parametrize(

tests/tests_pytorch/callbacks/test_callbacks.py

Lines changed: 12 additions & 12 deletions
@@ -119,21 +119,21 @@ def load_state_dict(self, state_dict) -> None:
         self.state = state_dict["state"]


+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_resume_callback_state_saved_by_type_stateful(tmp_path):
     """Test that a legacy checkpoint that didn't use a state key before can still be loaded, using
     state_dict/load_state_dict."""
-    with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
-        model = BoringModel()
-        callback = OldStatefulCallback(state=111)
-        trainer = Trainer(default_root_dir=tmp_path, max_steps=1, callbacks=[callback])
-        trainer.fit(model)
-        ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-        assert ckpt_path.exists()
-
-        callback = OldStatefulCallback(state=222)
-        trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback])
-        trainer.fit(model, ckpt_path=ckpt_path)
-        assert callback.state == 111
+    model = BoringModel()
+    callback = OldStatefulCallback(state=111)
+    trainer = Trainer(default_root_dir=tmp_path, max_steps=1, callbacks=[callback])
+    trainer.fit(model)
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()
+
+    callback = OldStatefulCallback(state=222)
+    trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback])
+    trainer.fit(model, ckpt_path=ckpt_path)
+    assert callback.state == 111


 def test_resume_incomplete_callbacks_list_warning(tmp_path):

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 134 additions & 134 deletions
@@ -35,73 +35,73 @@
 from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector


+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_checkpoint_callbacks_are_last(tmp_path):
     """Test that checkpoint callbacks always come last."""
-    with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
-        checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1")
-        checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2")
-        early_stopping = EarlyStopping(monitor="foo")
-        lr_monitor = LearningRateMonitor()
-        model_summary = ModelSummary()
-        progress_bar = TQDMProgressBar()
-
-        # no model reference
-        trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
-        assert trainer.callbacks == [
-            progress_bar,
-            lr_monitor,
-            model_summary,
-            checkpoint1,
-            checkpoint2,
-        ]
-
-        # no model callbacks
-        model = LightningModule()
-        model.configure_callbacks = lambda: []
-        trainer.strategy._lightning_module = model
-        cb_connector = _CallbackConnector(trainer)
-        cb_connector._attach_model_callbacks()
-        assert trainer.callbacks == [
-            progress_bar,
-            lr_monitor,
-            model_summary,
-            checkpoint1,
-            checkpoint2,
-        ]
-
-        # with model-specific callbacks that substitute ones in Trainer
-        model = LightningModule()
-        model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2]
-        trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")])
-        trainer.strategy._lightning_module = model
-        cb_connector = _CallbackConnector(trainer)
-        cb_connector._attach_model_callbacks()
-        assert trainer.callbacks == [
-            progress_bar,
-            lr_monitor,
-            early_stopping,
-            model_summary,
-            checkpoint1,
-            checkpoint2,
-        ]
-
-        # with tuner-specific callbacks that substitute ones in Trainer
-        model = LightningModule()
-        batch_size_finder = BatchSizeFinder()
-        model.configure_callbacks = lambda: [checkpoint2, early_stopping, batch_size_finder, model_summary, checkpoint1]
-        trainer = Trainer(callbacks=[progress_bar, lr_monitor])
-        trainer.strategy._lightning_module = model
-        cb_connector = _CallbackConnector(trainer)
-        cb_connector._attach_model_callbacks()
-        assert trainer.callbacks == [
-            batch_size_finder,
-            progress_bar,
-            lr_monitor,
-            early_stopping,
-            model_summary,
-            checkpoint2,
-            checkpoint1,
-        ]
+    checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1")
+    checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2")
+    early_stopping = EarlyStopping(monitor="foo")
+    lr_monitor = LearningRateMonitor()
+    model_summary = ModelSummary()
+    progress_bar = TQDMProgressBar()
+
+    # no model reference
+    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
+    assert trainer.callbacks == [
+        progress_bar,
+        lr_monitor,
+        model_summary,
+        checkpoint1,
+        checkpoint2,
+    ]
+
+    # no model callbacks
+    model = LightningModule()
+    model.configure_callbacks = lambda: []
+    trainer.strategy._lightning_module = model
+    cb_connector = _CallbackConnector(trainer)
+    cb_connector._attach_model_callbacks()
+    assert trainer.callbacks == [
+        progress_bar,
+        lr_monitor,
+        model_summary,
+        checkpoint1,
+        checkpoint2,
+    ]
+
+    # with model-specific callbacks that substitute ones in Trainer
+    model = LightningModule()
+    model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2]
+    trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")])
+    trainer.strategy._lightning_module = model
+    cb_connector = _CallbackConnector(trainer)
+    cb_connector._attach_model_callbacks()
+    assert trainer.callbacks == [
+        progress_bar,
+        lr_monitor,
+        early_stopping,
+        model_summary,
+        checkpoint1,
+        checkpoint2,
+    ]
+
+    # with tuner-specific callbacks that substitute ones in Trainer
+    model = LightningModule()
+    batch_size_finder = BatchSizeFinder()
+    model.configure_callbacks = lambda: [checkpoint2, early_stopping, batch_size_finder, model_summary, checkpoint1]
+    trainer = Trainer(callbacks=[progress_bar, lr_monitor])
+    trainer.strategy._lightning_module = model
+    cb_connector = _CallbackConnector(trainer)
+    cb_connector._attach_model_callbacks()
+    assert trainer.callbacks == [
+        batch_size_finder,
+        progress_bar,
+        lr_monitor,
+        early_stopping,
+        model_summary,
+        checkpoint2,
+        checkpoint1,
+    ]


 class StatefulCallback0(Callback):
@@ -162,81 +162,81 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
     )


+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
 def test_attach_model_callbacks():
     """Test that the callbacks defined in the model and through Trainer get merged correctly."""
-    with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
-
-        def _attach_callbacks(trainer_callbacks, model_callbacks):
-            model = LightningModule()
-            model.configure_callbacks = lambda: model_callbacks
-            has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks)
-            trainer = Trainer(
-                enable_checkpointing=False,
-                enable_progress_bar=has_progress_bar,
-                enable_model_summary=False,
-                callbacks=trainer_callbacks,
-            )
-            trainer.strategy._lightning_module = model
-            cb_connector = _CallbackConnector(trainer)
-            cb_connector._attach_model_callbacks()
-            return trainer
-
-        early_stopping1 = EarlyStopping(monitor="red")
-        early_stopping2 = EarlyStopping(monitor="blue")
-        progress_bar = TQDMProgressBar()
-        lr_monitor = LearningRateMonitor()
-        grad_accumulation = GradientAccumulationScheduler({1: 1})
-
-        # no callbacks
-        trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[])
-        assert trainer.callbacks == []
-
-        # callbacks of different types
-        trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar])
-        assert trainer.callbacks == [early_stopping1, progress_bar]
-
-        # same callback type twice, different instance
-        trainer = _attach_callbacks(
-            trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")],
-            model_callbacks=[early_stopping1],
-        )
-        assert trainer.callbacks == [progress_bar, early_stopping1]
-
-        # multiple callbacks of the same type in trainer
-        trainer = _attach_callbacks(
-            trainer_callbacks=[
-                LearningRateMonitor(),
-                EarlyStopping(monitor="yellow"),
-                LearningRateMonitor(),
-                EarlyStopping(monitor="black"),
-            ],
-            model_callbacks=[early_stopping1, lr_monitor],
-        )
-        assert trainer.callbacks == [early_stopping1, lr_monitor]
-
-        # multiple callbacks of the same type, in both trainer and model
-        trainer = _attach_callbacks(
-            trainer_callbacks=[
-                LearningRateMonitor(),
-                progress_bar,
-                EarlyStopping(monitor="yellow"),
-                LearningRateMonitor(),
-                EarlyStopping(monitor="black"),
-            ],
-            model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2],
+
+    def _attach_callbacks(trainer_callbacks, model_callbacks):
+        model = LightningModule()
+        model.configure_callbacks = lambda: model_callbacks
+        has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks)
+        trainer = Trainer(
+            enable_checkpointing=False,
+            enable_progress_bar=has_progress_bar,
+            enable_model_summary=False,
+            callbacks=trainer_callbacks,
         )
-        assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2]
+        trainer.strategy._lightning_module = model
+        cb_connector = _CallbackConnector(trainer)
+        cb_connector._attach_model_callbacks()
+        return trainer
+
+    early_stopping1 = EarlyStopping(monitor="red")
+    early_stopping2 = EarlyStopping(monitor="blue")
+    progress_bar = TQDMProgressBar()
+    lr_monitor = LearningRateMonitor()
+    grad_accumulation = GradientAccumulationScheduler({1: 1})
+
+    # no callbacks
+    trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[])
+    assert trainer.callbacks == []
+
+    # callbacks of different types
+    trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar])
+    assert trainer.callbacks == [early_stopping1, progress_bar]
+
+    # same callback type twice, different instance
+    trainer = _attach_callbacks(
+        trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")],
+        model_callbacks=[early_stopping1],
+    )
+    assert trainer.callbacks == [progress_bar, early_stopping1]
+
+    # multiple callbacks of the same type in trainer
+    trainer = _attach_callbacks(
+        trainer_callbacks=[
+            LearningRateMonitor(),
+            EarlyStopping(monitor="yellow"),
+            LearningRateMonitor(),
+            EarlyStopping(monitor="black"),
+        ],
+        model_callbacks=[early_stopping1, lr_monitor],
+    )
+    assert trainer.callbacks == [early_stopping1, lr_monitor]
+
+    # multiple callbacks of the same type, in both trainer and model
+    trainer = _attach_callbacks(
+        trainer_callbacks=[
+            LearningRateMonitor(),
+            progress_bar,
+            EarlyStopping(monitor="yellow"),
+            LearningRateMonitor(),
+            EarlyStopping(monitor="black"),
+        ],
+        model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2],
+    )
+    assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2]

-        class CustomProgressBar(TQDMProgressBar): ...
+    class CustomProgressBar(TQDMProgressBar): ...

-        custom_progress_bar = CustomProgressBar()
-        # a custom callback that overrides ours
-        trainer = _attach_callbacks(trainer_callbacks=[progress_bar], model_callbacks=[custom_progress_bar])
-        assert trainer.callbacks == [custom_progress_bar]
+    custom_progress_bar = CustomProgressBar()
+    # a custom callback that overrides ours
+    trainer = _attach_callbacks(trainer_callbacks=[progress_bar], model_callbacks=[custom_progress_bar])
+    assert trainer.callbacks == [custom_progress_bar]

-        # edge case
-        bare_callback = Callback()
-        trainer = _attach_callbacks(trainer_callbacks=[bare_callback], model_callbacks=[custom_progress_bar])
+    # edge case
+    bare_callback = Callback()
+    trainer = _attach_callbacks(trainer_callbacks=[bare_callback], model_callbacks=[custom_progress_bar])
     assert trainer.callbacks == [bare_callback, custom_progress_bar]
