35 | 35 | from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector |
36 | 36 |
37 | 37 |
| 38 | +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) |
38 | 39 | def test_checkpoint_callbacks_are_last(tmp_path): |
39 | 40 | """Test that checkpoint callbacks always come last.""" |
40 | | - with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
41 | | - checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1") |
42 | | - checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2") |
43 | | - early_stopping = EarlyStopping(monitor="foo") |
44 | | - lr_monitor = LearningRateMonitor() |
45 | | - model_summary = ModelSummary() |
46 | | - progress_bar = TQDMProgressBar() |
47 | | - |
48 | | - # no model reference |
49 | | - trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) |
50 | | - assert trainer.callbacks == [ |
51 | | - progress_bar, |
52 | | - lr_monitor, |
53 | | - model_summary, |
54 | | - checkpoint1, |
55 | | - checkpoint2, |
56 | | - ] |
57 | | - |
58 | | - # no model callbacks |
59 | | - model = LightningModule() |
60 | | - model.configure_callbacks = lambda: [] |
61 | | - trainer.strategy._lightning_module = model |
62 | | - cb_connector = _CallbackConnector(trainer) |
63 | | - cb_connector._attach_model_callbacks() |
64 | | - assert trainer.callbacks == [ |
65 | | - progress_bar, |
66 | | - lr_monitor, |
67 | | - model_summary, |
68 | | - checkpoint1, |
69 | | - checkpoint2, |
70 | | - ] |
71 | | - |
72 | | - # with model-specific callbacks that substitute ones in Trainer |
73 | | - model = LightningModule() |
74 | | - model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] |
75 | | - trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")]) |
76 | | - trainer.strategy._lightning_module = model |
77 | | - cb_connector = _CallbackConnector(trainer) |
78 | | - cb_connector._attach_model_callbacks() |
79 | | - assert trainer.callbacks == [ |
80 | | - progress_bar, |
81 | | - lr_monitor, |
82 | | - early_stopping, |
83 | | - model_summary, |
84 | | - checkpoint1, |
85 | | - checkpoint2, |
86 | | - ] |
87 | | - |
88 | | - # with tuner-specific callbacks that substitute ones in Trainer |
89 | | - model = LightningModule() |
90 | | - batch_size_finder = BatchSizeFinder() |
91 | | - model.configure_callbacks = lambda: [checkpoint2, early_stopping, batch_size_finder, model_summary, checkpoint1] |
92 | | - trainer = Trainer(callbacks=[progress_bar, lr_monitor]) |
93 | | - trainer.strategy._lightning_module = model |
94 | | - cb_connector = _CallbackConnector(trainer) |
95 | | - cb_connector._attach_model_callbacks() |
96 | | - assert trainer.callbacks == [ |
97 | | - batch_size_finder, |
98 | | - progress_bar, |
99 | | - lr_monitor, |
100 | | - early_stopping, |
101 | | - model_summary, |
102 | | - checkpoint2, |
103 | | - checkpoint1, |
104 | | - ] |
| 41 | +    checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1") |
| 42 | +    checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2") |
| 43 | +    early_stopping = EarlyStopping(monitor="foo") |
| 44 | +    lr_monitor = LearningRateMonitor() |
| 45 | +    model_summary = ModelSummary() |
| 46 | +    progress_bar = TQDMProgressBar() |
| 47 | + |
| 48 | +    # no model reference |
| 49 | +    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) |
| 50 | +    assert trainer.callbacks == [ |
| 51 | +        progress_bar, |
| 52 | +        lr_monitor, |
| 53 | +        model_summary, |
| 54 | +        checkpoint1, |
| 55 | +        checkpoint2, |
| 56 | +    ] |
| 57 | + |
| 58 | +    # no model callbacks |
| 59 | +    model = LightningModule() |
| 60 | +    model.configure_callbacks = lambda: [] |
| 61 | +    trainer.strategy._lightning_module = model |
| 62 | +    cb_connector = _CallbackConnector(trainer) |
| 63 | +    cb_connector._attach_model_callbacks() |
| 64 | +    assert trainer.callbacks == [ |
| 65 | +        progress_bar, |
| 66 | +        lr_monitor, |
| 67 | +        model_summary, |
| 68 | +        checkpoint1, |
| 69 | +        checkpoint2, |
| 70 | +    ] |
| 71 | + |
| 72 | +    # with model-specific callbacks that substitute ones in Trainer |
| 73 | +    model = LightningModule() |
| 74 | +    model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] |
| 75 | +    trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")]) |
| 76 | +    trainer.strategy._lightning_module = model |
| 77 | +    cb_connector = _CallbackConnector(trainer) |
| 78 | +    cb_connector._attach_model_callbacks() |
| 79 | +    assert trainer.callbacks == [ |
| 80 | +        progress_bar, |
| 81 | +        lr_monitor, |
| 82 | +        early_stopping, |
| 83 | +        model_summary, |
| 84 | +        checkpoint1, |
| 85 | +        checkpoint2, |
| 86 | +    ] |
| 87 | + |
| 88 | +    # with tuner-specific callbacks that substitute ones in Trainer |
| 89 | +    model = LightningModule() |
| 90 | +    batch_size_finder = BatchSizeFinder() |
| 91 | +    model.configure_callbacks = lambda: [checkpoint2, early_stopping, batch_size_finder, model_summary, checkpoint1] |
| 92 | +    trainer = Trainer(callbacks=[progress_bar, lr_monitor]) |
| 93 | +    trainer.strategy._lightning_module = model |
| 94 | +    cb_connector = _CallbackConnector(trainer) |
| 95 | +    cb_connector._attach_model_callbacks() |
| 96 | +    assert trainer.callbacks == [ |
| 97 | +        batch_size_finder, |
| 98 | +        progress_bar, |
| 99 | +        lr_monitor, |
| 100 | +        early_stopping, |
| 101 | +        model_summary, |
| 102 | +        checkpoint2, |
| 103 | +        checkpoint1, |
| 104 | +    ] |
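
Note on the change above: it swaps `unittest.mock.patch` from its context-manager form to its decorator form. A minimal, self-contained sketch of the two forms (`fake_module` here is a hypothetical stand-in for the module holding `_RICH_AVAILABLE`, not Lightning's real API):

    import types
    from unittest.mock import patch

    # Hypothetical stand-in for a module exposing an availability flag.
    fake_module = types.SimpleNamespace(RICH_AVAILABLE=True)

    def test_with_context_manager():
        # The patch is active only inside the block and is undone on exit.
        with patch.object(fake_module, "RICH_AVAILABLE", False):
            assert fake_module.RICH_AVAILABLE is False
        assert fake_module.RICH_AVAILABLE is True

    # Because an explicit replacement value (False) is supplied, patch does not
    # inject a mock argument into the test; the patch spans the whole body.
    @patch.object(fake_module, "RICH_AVAILABLE", False)
    def test_with_decorator():
        assert fake_module.RICH_AVAILABLE is False

Both forms restore the original value when the test finishes; the decorator form simply removes one level of indentation, which is all this refactor changes.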
105 | 105 |
106 | 106 |
107 | 107 | class StatefulCallback0(Callback): |
@@ -162,81 +162,81 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path): |
162 | 162 |     ) |
163 | 163 |
164 | 164 |
| 165 | +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) |
165 | 166 | def test_attach_model_callbacks(): |
166 | 167 | """Test that the callbacks defined in the model and through Trainer get merged correctly.""" |
167 | | - with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
168 | | - |
169 | | - def _attach_callbacks(trainer_callbacks, model_callbacks): |
170 | | - model = LightningModule() |
171 | | - model.configure_callbacks = lambda: model_callbacks |
172 | | - has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks) |
173 | | - trainer = Trainer( |
174 | | - enable_checkpointing=False, |
175 | | - enable_progress_bar=has_progress_bar, |
176 | | - enable_model_summary=False, |
177 | | - callbacks=trainer_callbacks, |
178 | | - ) |
179 | | - trainer.strategy._lightning_module = model |
180 | | - cb_connector = _CallbackConnector(trainer) |
181 | | - cb_connector._attach_model_callbacks() |
182 | | - return trainer |
183 | | - |
184 | | - early_stopping1 = EarlyStopping(monitor="red") |
185 | | - early_stopping2 = EarlyStopping(monitor="blue") |
186 | | - progress_bar = TQDMProgressBar() |
187 | | - lr_monitor = LearningRateMonitor() |
188 | | - grad_accumulation = GradientAccumulationScheduler({1: 1}) |
189 | | - |
190 | | - # no callbacks |
191 | | - trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[]) |
192 | | - assert trainer.callbacks == [] |
193 | | - |
194 | | - # callbacks of different types |
195 | | - trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar]) |
196 | | - assert trainer.callbacks == [early_stopping1, progress_bar] |
197 | | - |
198 | | - # same callback type twice, different instance |
199 | | - trainer = _attach_callbacks( |
200 | | - trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")], |
201 | | - model_callbacks=[early_stopping1], |
202 | | - ) |
203 | | - assert trainer.callbacks == [progress_bar, early_stopping1] |
204 | | - |
205 | | - # multiple callbacks of the same type in trainer |
206 | | - trainer = _attach_callbacks( |
207 | | - trainer_callbacks=[ |
208 | | - LearningRateMonitor(), |
209 | | - EarlyStopping(monitor="yellow"), |
210 | | - LearningRateMonitor(), |
211 | | - EarlyStopping(monitor="black"), |
212 | | - ], |
213 | | - model_callbacks=[early_stopping1, lr_monitor], |
214 | | - ) |
215 | | - assert trainer.callbacks == [early_stopping1, lr_monitor] |
216 | | - |
217 | | - # multiple callbacks of the same type, in both trainer and model |
218 | | - trainer = _attach_callbacks( |
219 | | - trainer_callbacks=[ |
220 | | - LearningRateMonitor(), |
221 | | - progress_bar, |
222 | | - EarlyStopping(monitor="yellow"), |
223 | | - LearningRateMonitor(), |
224 | | - EarlyStopping(monitor="black"), |
225 | | - ], |
226 | | - model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2], |
| 168 | + |
| 169 | +    def _attach_callbacks(trainer_callbacks, model_callbacks): |
| 170 | +        model = LightningModule() |
| 171 | +        model.configure_callbacks = lambda: model_callbacks |
| 172 | +        has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks) |
| 173 | +        trainer = Trainer( |
| 174 | +            enable_checkpointing=False, |
| 175 | +            enable_progress_bar=has_progress_bar, |
| 176 | +            enable_model_summary=False, |
| 177 | +            callbacks=trainer_callbacks, |
227 | 178 |         ) |
228 | | -        assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2] |
| 179 | +        trainer.strategy._lightning_module = model |
| 180 | +        cb_connector = _CallbackConnector(trainer) |
| 181 | +        cb_connector._attach_model_callbacks() |
| 182 | +        return trainer |
| 183 | + |
| 184 | +    early_stopping1 = EarlyStopping(monitor="red") |
| 185 | +    early_stopping2 = EarlyStopping(monitor="blue") |
| 186 | +    progress_bar = TQDMProgressBar() |
| 187 | +    lr_monitor = LearningRateMonitor() |
| 188 | +    grad_accumulation = GradientAccumulationScheduler({1: 1}) |
| 189 | + |
| 190 | +    # no callbacks |
| 191 | +    trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[]) |
| 192 | +    assert trainer.callbacks == [] |
| 193 | + |
| 194 | +    # callbacks of different types |
| 195 | +    trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar]) |
| 196 | +    assert trainer.callbacks == [early_stopping1, progress_bar] |
| 197 | + |
| 198 | +    # same callback type twice, different instance |
| 199 | +    trainer = _attach_callbacks( |
| 200 | +        trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")], |
| 201 | +        model_callbacks=[early_stopping1], |
| 202 | +    ) |
| 203 | +    assert trainer.callbacks == [progress_bar, early_stopping1] |
| 204 | + |
| 205 | +    # multiple callbacks of the same type in trainer |
| 206 | +    trainer = _attach_callbacks( |
| 207 | +        trainer_callbacks=[ |
| 208 | +            LearningRateMonitor(), |
| 209 | +            EarlyStopping(monitor="yellow"), |
| 210 | +            LearningRateMonitor(), |
| 211 | +            EarlyStopping(monitor="black"), |
| 212 | +        ], |
| 213 | +        model_callbacks=[early_stopping1, lr_monitor], |
| 214 | +    ) |
| 215 | +    assert trainer.callbacks == [early_stopping1, lr_monitor] |
| 216 | + |
| 217 | +    # multiple callbacks of the same type, in both trainer and model |
| 218 | +    trainer = _attach_callbacks( |
| 219 | +        trainer_callbacks=[ |
| 220 | +            LearningRateMonitor(), |
| 221 | +            progress_bar, |
| 222 | +            EarlyStopping(monitor="yellow"), |
| 223 | +            LearningRateMonitor(), |
| 224 | +            EarlyStopping(monitor="black"), |
| 225 | +        ], |
| 226 | +        model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2], |
| 227 | +    ) |
| 228 | +    assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2] |
229 | 229 |
230 | | -        class CustomProgressBar(TQDMProgressBar): ... |
| 230 | +    class CustomProgressBar(TQDMProgressBar): ... |
231 | 231 |
232 | | -        custom_progress_bar = CustomProgressBar() |
233 | | -        # a custom callback that overrides ours |
234 | | -        trainer = _attach_callbacks(trainer_callbacks=[progress_bar], model_callbacks=[custom_progress_bar]) |
235 | | -        assert trainer.callbacks == [custom_progress_bar] |
| 232 | +    custom_progress_bar = CustomProgressBar() |
| 233 | +    # a custom callback that overrides ours |
| 234 | +    trainer = _attach_callbacks(trainer_callbacks=[progress_bar], model_callbacks=[custom_progress_bar]) |
| 235 | +    assert trainer.callbacks == [custom_progress_bar] |
236 | 236 |
237 | | -        # edge case: a bare Callback is not the same type as the progress bar subclass, so both are kept |
238 | | -        bare_callback = Callback() |
239 | | -        trainer = _attach_callbacks(trainer_callbacks=[bare_callback], model_callbacks=[custom_progress_bar]) |
| 237 | +    # edge case: a bare Callback is not the same type as the progress bar subclass, so both are kept |
| 238 | +    bare_callback = Callback() |
| 239 | +    trainer = _attach_callbacks(trainer_callbacks=[bare_callback], model_callbacks=[custom_progress_bar]) |
240 | 240 |     assert trainer.callbacks == [bare_callback, custom_progress_bar] |
241 | 241 |
242 | 242 |
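
Taken together, the assertions above encode the connector's merge-and-reorder rule: callbacks returned by the model's `configure_callbacks` replace trainer callbacks of the same type, tuner callbacks such as `BatchSizeFinder` move to the front, and checkpoint callbacks move to the end with their relative order preserved. A simplified sketch of just the reordering step, assuming the `Checkpoint` base class exported by `lightning.pytorch.callbacks` (an illustration, not the actual `_CallbackConnector` code):

    from lightning.pytorch.callbacks import BatchSizeFinder, Checkpoint

    def reorder_callbacks(callbacks):
        # Tuner callbacks first, checkpoints last, everything else in between;
        # each group keeps its original relative order.
        tuner = [cb for cb in callbacks if isinstance(cb, BatchSizeFinder)]
        checkpoints = [cb for cb in callbacks if isinstance(cb, Checkpoint)]
        rest = [cb for cb in callbacks if cb not in tuner and cb not in checkpoints]
        return tuner + rest + checkpoints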