
Commit ccfd1d8

rohitgr7 authored and committed

Fix support for logging within callbacks returned from LightningModule (#10991)

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>

1 parent 93cda24 · commit ccfd1d8

File tree

5 files changed (+30, -14 lines)
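
For context, here is a minimal sketch of the scenario this commit fixes, assuming the 1.5-series pytorch_lightning API. The module, callback, metric names, and toy data below are illustrative, not part of the change: a callback returned from `LightningModule.configure_callbacks` calls `self.log`, which only works once the trainer has attached the module's logging functions to that callback.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Callback, LightningModule, Trainer


class EpochMarkerCallback(Callback):
    # Illustrative callback: `self.log` is the LightningModule's log method,
    # attached to every callback by the trainer before training starts.
    def on_train_epoch_end(self, trainer, pl_module):
        self.log("epoch_marker", float(trainer.current_epoch))


class DemoModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).mean()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def configure_callbacks(self):
        # Before this fix, a callback returned here could miss log/log_dict,
        # because they were attached before this hook was collected.
        return [EpochMarkerCallback()]


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)
    Trainer(max_epochs=1, logger=False).fit(DemoModule(), data)

With the reordered setup in this commit, `EpochMarkerCallback` can log just like a callback passed to the `Trainer` directly.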

CHANGELOG.md

Lines changed: 12 additions & 0 deletions

@@ -9,7 +9,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed a bug where the DeepSpeedPlugin arguments `cpu_checkpointing` and `contiguous_memory_optimization` were not being forwarded to deepspeed correctly ([#10874](https://github.com/PyTorchLightning/pytorch-lightning/issues/10874))
+
+
 - Fixed an issue with `NeptuneLogger` causing checkpoints to be uploaded with a duplicated file extension ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/issues/11015))
+
+
+
+- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))
+
+
+-
+
+
+-


 ## [1.5.5] - 2021-12-07

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 5 additions & 1 deletion

@@ -19,7 +19,7 @@
 from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn


-def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     r"""
     Checks that the model is configured correctly before the run is started.

@@ -28,6 +28,10 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
         model: The model to check the configuration.

     """
+    model = trainer.lightning_module
+
+    if trainer.state.fn is None:
+        raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
     if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
         __verify_train_val_loop_configuration(trainer, model)
         __verify_manual_optimization_support(trainer, model)

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 4 additions & 3 deletions

@@ -249,10 +249,11 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
     def _trainer_has_checkpoint_callbacks(self):
         return len(self.trainer.checkpoint_callbacks) > 0

-    def attach_model_logging_functions(self, model):
+    def _attach_model_logging_functions(self):
+        lightning_module = self.trainer.lightning_module
         for callback in self.trainer.callbacks:
-            callback.log = model.log
-            callback.log_dict = model.log_dict
+            callback.log = lightning_module.log
+            callback.log_dict = lightning_module.log_dict

     def _attach_model_callbacks(self) -> None:
         """Attaches the callbacks defined in the model.

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 6 deletions

@@ -1114,17 +1114,16 @@ def _run(
         if hasattr(model, "hparams"):
             parsing.clean_namespace(model.hparams)

-        verify_loop_configurations(self, model)
-
-        # attach model log function to callback
-        self._callback_connector.attach_model_logging_functions(model)
-
         # attach model to the training type plugin
         self.training_type_plugin.connect(model)

+        self._callback_connector._attach_model_callbacks()
+        self._callback_connector._attach_model_logging_functions()
+
+        verify_loop_configurations(self)
+
         # hook
         self._data_connector.prepare_data()
-        self._callback_connector._attach_model_callbacks()

         # ----------------------------
         # SET UP TRAINING
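
The reordering above is the heart of the fix: callbacks declared in `configure_callbacks` only become known to the trainer when `_attach_model_callbacks()` runs, so the logging functions have to be attached after that point, and loop verification now runs last on the fully assembled trainer. A standalone sketch of why the order matters (toy classes, not library code):

class DemoModule:
    def log(self, name, value):
        print(f"{name}={value}")

    def configure_callbacks(self):
        return [DemoCallback()]


class DemoCallback:
    def on_train_epoch_end(self):
        # Raises AttributeError unless `log` was attached after collection.
        self.log("marker", 1.0)


module = DemoModule()
callbacks = [DemoCallback()]            # callbacks passed to the trainer up front

# New order: first collect the model-declared callbacks, then attach logging.
callbacks += module.configure_callbacks()
for cb in callbacks:
    cb.log = module.log

for cb in callbacks:
    cb.on_train_epoch_end()             # every callback can log now

# With the old order (attach first, collect afterwards), the callback returned
# from configure_callbacks would never receive `log` and would fail here.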

tests/models/test_hooks.py

Lines changed: 4 additions & 4 deletions

@@ -499,8 +499,8 @@ def training_step(self, batch, batch_idx):
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         # DeepSpeed needs the batch size to figure out throughput logging
         *([dict(name="train_dataloader")] if kwargs.get("strategy") == "deepspeed" else []),
@@ -618,8 +618,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
@@ -716,8 +716,8 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage=verb)),
         dict(name="setup", kwargs=dict(stage=verb)),
@@ -748,8 +748,8 @@ def test_trainer_model_hook_system_predict(tmpdir):
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="predict")),
         dict(name="setup", kwargs=dict(stage="predict")),
