|
35 | 35 | from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
|
36 | 36 |
|
37 | 37 |
|
| 38 | +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) |
38 | 39 | def test_checkpoint_callbacks_are_last(tmp_path):
|
39 | 40 | """Test that checkpoint callbacks always come last."""
|
40 |
| - with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
41 |
| - checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1") |
42 |
| - checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2") |
43 |
| - early_stopping = EarlyStopping(monitor="foo") |
44 |
| - lr_monitor = LearningRateMonitor() |
45 |
| - model_summary = ModelSummary() |
46 |
| - progress_bar = TQDMProgressBar() |
47 |
| - |
48 |
| - # no model reference |
49 |
| - trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) |
50 |
| - assert trainer.callbacks == [ |
51 |
| - progress_bar, |
52 |
| - lr_monitor, |
53 |
| - model_summary, |
54 |
| - checkpoint1, |
55 |
| - checkpoint2, |
56 |
| - ] |
57 |
| - |
58 |
| - # no model callbacks |
59 |
| - model = LightningModule() |
60 |
| - model.configure_callbacks = lambda: [] |
61 |
| - trainer.strategy._lightning_module = model |
62 |
| - cb_connector = _CallbackConnector(trainer) |
63 |
| - cb_connector._attach_model_callbacks() |
64 |
| - assert trainer.callbacks == [ |
65 |
| - progress_bar, |
66 |
| - lr_monitor, |
67 |
| - model_summary, |
68 |
| - checkpoint1, |
69 |
| - checkpoint2, |
70 |
| - ] |
71 |
| - |
72 |
| - # with model-specific callbacks that substitute ones in Trainer |
73 |
| - model = LightningModule() |
74 |
| - model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] |
75 |
| - trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")]) |
76 |
| - trainer.strategy._lightning_module = model |
77 |
| - cb_connector = _CallbackConnector(trainer) |
78 |
| - cb_connector._attach_model_callbacks() |
79 |
| - assert trainer.callbacks == [ |
80 |
| - progress_bar, |
81 |
| - lr_monitor, |
82 |
| - early_stopping, |
83 |
| - model_summary, |
84 |
| - checkpoint1, |
85 |
| - checkpoint2, |
86 |
| - ] |
87 |
| - |
88 |
| - # with tuner-specific callbacks that substitute ones in Trainer |
89 |
| - model = LightningModule() |
90 |
| - batch_size_finder = BatchSizeFinder() |
91 |
| - model.configure_callbacks = lambda: [checkpoint2, early_stopping, batch_size_finder, model_summary, checkpoint1] |
92 |
| - trainer = Trainer(callbacks=[progress_bar, lr_monitor]) |
93 |
| - trainer.strategy._lightning_module = model |
94 |
| - cb_connector = _CallbackConnector(trainer) |
95 |
| - cb_connector._attach_model_callbacks() |
96 |
| - assert trainer.callbacks == [ |
97 |
| - batch_size_finder, |
98 |
| - progress_bar, |
99 |
| - lr_monitor, |
100 |
| - early_stopping, |
101 |
| - model_summary, |
102 |
| - checkpoint2, |
103 |
| - checkpoint1, |
104 |
| - ] |
| 41 | + checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1") |
| 42 | + checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2") |
| 43 | + early_stopping = EarlyStopping(monitor="foo") |
| 44 | + lr_monitor = LearningRateMonitor() |
| 45 | + model_summary = ModelSummary() |
| 46 | + progress_bar = TQDMProgressBar() |
| 47 | + |
| 48 | + # no model reference |
| 49 | + trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) |
| 50 | + assert trainer.callbacks == [ |
| 51 | + progress_bar, |
| 52 | + lr_monitor, |
| 53 | + model_summary, |
| 54 | + checkpoint1, |
| 55 | + checkpoint2, |
| 56 | + ] |
| 57 | + |
| 58 | + # no model callbacks |
| 59 | + model = LightningModule() |
| 60 | + model.configure_callbacks = lambda: [] |
| 61 | + trainer.strategy._lightning_module = model |
| 62 | + cb_connector = _CallbackConnector(trainer) |
| 63 | + cb_connector._attach_model_callbacks() |
| 64 | + assert trainer.callbacks == [ |
| 65 | + progress_bar, |
| 66 | + lr_monitor, |
| 67 | + model_summary, |
| 68 | + checkpoint1, |
| 69 | + checkpoint2, |
| 70 | + ] |
| 71 | + |
| 72 | + # with model-specific callbacks that substitute ones in Trainer |
| 73 | + model = LightningModule() |
| 74 | + model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] |
| 75 | + trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")]) |
| 76 | + trainer.strategy._lightning_module = model |
| 77 | + cb_connector = _CallbackConnector(trainer) |
| 78 | + cb_connector._attach_model_callbacks() |
| 79 | + assert trainer.callbacks == [ |
| 80 | + progress_bar, |
| 81 | + lr_monitor, |
| 82 | + early_stopping, |
| 83 | + model_summary, |
| 84 | + checkpoint1, |
| 85 | + checkpoint2, |
| 86 | + ] |
| 87 | + |
| 88 | + # with tuner-specific callbacks that substitute ones in Trainer |
| 89 | + model = LightningModule() |
| 90 | + batch_size_finder = BatchSizeFinder() |
| 91 | + model.configure_callbacks = lambda: [checkpoint2, early_stopping, batch_size_finder, model_summary, checkpoint1] |
| 92 | + trainer = Trainer(callbacks=[progress_bar, lr_monitor]) |
| 93 | + trainer.strategy._lightning_module = model |
| 94 | + cb_connector = _CallbackConnector(trainer) |
| 95 | + cb_connector._attach_model_callbacks() |
| 96 | + assert trainer.callbacks == [ |
| 97 | + batch_size_finder, |
| 98 | + progress_bar, |
| 99 | + lr_monitor, |
| 100 | + early_stopping, |
| 101 | + model_summary, |
| 102 | + checkpoint2, |
| 103 | + checkpoint1, |
| 104 | + ] |
105 | 105 |
|
106 | 106 |
|
107 | 107 | class StatefulCallback0(Callback):
|
@@ -162,81 +162,81 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
|
162 | 162 | )
|
163 | 163 |
|
164 | 164 |
|
| 165 | +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) |
165 | 166 | def test_attach_model_callbacks():
|
166 | 167 | """Test that the callbacks defined in the model and through Trainer get merged correctly."""
|
167 |
| - with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
168 |
| - |
169 |
| - def _attach_callbacks(trainer_callbacks, model_callbacks): |
170 |
| - model = LightningModule() |
171 |
| - model.configure_callbacks = lambda: model_callbacks |
172 |
| - has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks) |
173 |
| - trainer = Trainer( |
174 |
| - enable_checkpointing=False, |
175 |
| - enable_progress_bar=has_progress_bar, |
176 |
| - enable_model_summary=False, |
177 |
| - callbacks=trainer_callbacks, |
178 |
| - ) |
179 |
| - trainer.strategy._lightning_module = model |
180 |
| - cb_connector = _CallbackConnector(trainer) |
181 |
| - cb_connector._attach_model_callbacks() |
182 |
| - return trainer |
183 |
| - |
184 |
| - early_stopping1 = EarlyStopping(monitor="red") |
185 |
| - early_stopping2 = EarlyStopping(monitor="blue") |
186 |
| - progress_bar = TQDMProgressBar() |
187 |
| - lr_monitor = LearningRateMonitor() |
188 |
| - grad_accumulation = GradientAccumulationScheduler({1: 1}) |
189 |
| - |
190 |
| - # no callbacks |
191 |
| - trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[]) |
192 |
| - assert trainer.callbacks == [] |
193 |
| - |
194 |
| - # callbacks of different types |
195 |
| - trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar]) |
196 |
| - assert trainer.callbacks == [early_stopping1, progress_bar] |
197 |
| - |
198 |
| - # same callback type twice, different instance |
199 |
| - trainer = _attach_callbacks( |
200 |
| - trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")], |
201 |
| - model_callbacks=[early_stopping1], |
202 |
| - ) |
203 |
| - assert trainer.callbacks == [progress_bar, early_stopping1] |
204 |
| - |
205 |
| - # multiple callbacks of the same type in trainer |
206 |
| - trainer = _attach_callbacks( |
207 |
| - trainer_callbacks=[ |
208 |
| - LearningRateMonitor(), |
209 |
| - EarlyStopping(monitor="yellow"), |
210 |
| - LearningRateMonitor(), |
211 |
| - EarlyStopping(monitor="black"), |
212 |
| - ], |
213 |
| - model_callbacks=[early_stopping1, lr_monitor], |
214 |
| - ) |
215 |
| - assert trainer.callbacks == [early_stopping1, lr_monitor] |
216 |
| - |
217 |
| - # multiple callbacks of the same type, in both trainer and model |
218 |
| - trainer = _attach_callbacks( |
219 |
| - trainer_callbacks=[ |
220 |
| - LearningRateMonitor(), |
221 |
| - progress_bar, |
222 |
| - EarlyStopping(monitor="yellow"), |
223 |
| - LearningRateMonitor(), |
224 |
| - EarlyStopping(monitor="black"), |
225 |
| - ], |
226 |
| - model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2], |
| 168 | + |
| 169 | + def _attach_callbacks(trainer_callbacks, model_callbacks): |
| 170 | + model = LightningModule() |
| 171 | + model.configure_callbacks = lambda: model_callbacks |
| 172 | + has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks) |
| 173 | + trainer = Trainer( |
| 174 | + enable_checkpointing=False, |
| 175 | + enable_progress_bar=has_progress_bar, |
| 176 | + enable_model_summary=False, |
| 177 | + callbacks=trainer_callbacks, |
227 | 178 | )
|
228 |
| - assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2] |
| 179 | + trainer.strategy._lightning_module = model |
| 180 | + cb_connector = _CallbackConnector(trainer) |
| 181 | + cb_connector._attach_model_callbacks() |
| 182 | + return trainer |
| 183 | + |
| 184 | + early_stopping1 = EarlyStopping(monitor="red") |
| 185 | + early_stopping2 = EarlyStopping(monitor="blue") |
| 186 | + progress_bar = TQDMProgressBar() |
| 187 | + lr_monitor = LearningRateMonitor() |
| 188 | + grad_accumulation = GradientAccumulationScheduler({1: 1}) |
| 189 | + |
| 190 | + # no callbacks |
| 191 | + trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[]) |
| 192 | + assert trainer.callbacks == [] |
| 193 | + |
| 194 | + # callbacks of different types |
| 195 | + trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar]) |
| 196 | + assert trainer.callbacks == [early_stopping1, progress_bar] |
| 197 | + |
| 198 | + # same callback type twice, different instance |
| 199 | + trainer = _attach_callbacks( |
| 200 | + trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")], |
| 201 | + model_callbacks=[early_stopping1], |
| 202 | + ) |
| 203 | + assert trainer.callbacks == [progress_bar, early_stopping1] |
| 204 | + |
| 205 | + # multiple callbacks of the same type in trainer |
| 206 | + trainer = _attach_callbacks( |
| 207 | + trainer_callbacks=[ |
| 208 | + LearningRateMonitor(), |
| 209 | + EarlyStopping(monitor="yellow"), |
| 210 | + LearningRateMonitor(), |
| 211 | + EarlyStopping(monitor="black"), |
| 212 | + ], |
| 213 | + model_callbacks=[early_stopping1, lr_monitor], |
| 214 | + ) |
| 215 | + assert trainer.callbacks == [early_stopping1, lr_monitor] |
| 216 | + |
| 217 | + # multiple callbacks of the same type, in both trainer and model |
| 218 | + trainer = _attach_callbacks( |
| 219 | + trainer_callbacks=[ |
| 220 | + LearningRateMonitor(), |
| 221 | + progress_bar, |
| 222 | + EarlyStopping(monitor="yellow"), |
| 223 | + LearningRateMonitor(), |
| 224 | + EarlyStopping(monitor="black"), |
| 225 | + ], |
| 226 | + model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2], |
| 227 | + ) |
| 228 | + assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2] |
229 | 229 |
|
230 |
| - class CustomProgressBar(TQDMProgressBar): ... |
| 230 | + class CustomProgressBar(TQDMProgressBar): ... |
231 | 231 |
|
232 |
| - custom_progress_bar = CustomProgressBar() |
233 |
| - # a custom callback that overrides ours |
234 |
| - trainer = _attach_callbacks(trainer_callbacks=[progress_bar], model_callbacks=[custom_progress_bar]) |
235 |
| - assert trainer.callbacks == [custom_progress_bar] |
| 232 | + custom_progress_bar = CustomProgressBar() |
| 233 | + # a custom callback that overrides ours |
| 234 | + trainer = _attach_callbacks(trainer_callbacks=[progress_bar], model_callbacks=[custom_progress_bar]) |
| 235 | + assert trainer.callbacks == [custom_progress_bar] |
236 | 236 |
|
237 |
| - # edge case |
238 |
| - bare_callback = Callback() |
239 |
| - trainer = _attach_callbacks(trainer_callbacks=[bare_callback], model_callbacks=[custom_progress_bar]) |
| 237 | + # edge case |
| 238 | + bare_callback = Callback() |
| 239 | + trainer = _attach_callbacks(trainer_callbacks=[bare_callback], model_callbacks=[custom_progress_bar]) |
240 | 240 | assert trainer.callbacks == [bare_callback, custom_progress_bar]
|
241 | 241 |
|
242 | 242 |
|
|
0 commit comments