
Commit 18d2ae8

Fix logging on_train_batch_end in a callback with multiple optimizers (#5521)
* Start with the failing test
* Then fix the failing test
* Update CHANGELOG
1 parent a56f745 commit 18d2ae8
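
For context on the commit title above: the failure mode being fixed is a LightningModule that configures multiple optimizers combined with a Callback that logs from `on_train_batch_end`. The sketch below is illustrative only — the names `RandomDataset`, `TwoOptimizerModule` and `LoggingCallback` are made up for this example, and it assumes the pytorch-lightning API of this release line (an `optimizer_idx` argument in `training_step`, a `dataloader_idx` argument on `on_train_batch_end`). The test updated in this commit exercises the same pattern with `BoringModel`.

# Minimal sketch of the scenario this commit targets (hypothetical names, not part of this commit).
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class TwoOptimizerModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        loss = self.layer(batch).sum()
        # one metric per optimizer, as in the updated test below
        self.log(f"loss_{optimizer_idx}", loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        opt_a = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        opt_b = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        return [opt_a, opt_b]

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)


class LoggingCallback(pl.Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # before this fix, logging here with multiple optimizers failed (KeyError, see #5459)
        pl_module.log("batch_end_metric", float(batch_idx))


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=5, callbacks=[LoggingCallback()])
    trainer.fit(TwoOptimizerModule())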

File tree

3 files changed (+30, -26 lines)

* CHANGELOG.md
* pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
* tests/trainer/optimization/test_multiple_optimizers.py


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579))


+- Fixed logging on_train_batch_end in a callback with multiple optimizers ([#5521](https://github.com/PyTorchLightning/pytorch-lightning/pull/5521))
+
+
 - Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))


pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 1 addition & 7 deletions
@@ -203,13 +203,7 @@ def auto_reduce_results_on_epoch_end(self) -> None:
             epoch_metrics = self._internals[dl_idx]

             if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
-
-                num_opt_idx = len(self._internals[dl_idx]) - 1
-
-                # Make sure we didn't create key
-                assert num_opt_idx >= 0
-
-                for opt_idx in range(num_opt_idx + 1):
+                for opt_idx in list(epoch_metrics):
                     # TODO: Figure out to reduce memory
                     # TODO: How to start training in middle of epoch
                     opt_outputs = epoch_metrics[opt_idx]
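
This loop is the heart of the fix. The removed code derived `num_opt_idx` from the number of stored keys and then indexed `range(num_opt_idx + 1)`, which only works if the stored optimizer indices form a contiguous range starting at 0; when they do not, the lookup raises the KeyError referenced in the test below (#5459). Iterating `list(epoch_metrics)` visits exactly the keys that were stored. A standalone sketch with hypothetical data (not the real EpochResultStore internals) illustrates the difference:

# Standalone sketch with hypothetical data (not the actual EpochResultStore) showing
# why iterating the stored keys is safer than assuming a contiguous 0..N-1 range.
epoch_metrics = {1: "results captured for optimizer 1"}  # only opt_idx 1 was ever logged

# removed behaviour: derive the indices from the number of keys
num_opt_idx = len(epoch_metrics) - 1  # == 0
try:
    for opt_idx in range(num_opt_idx + 1):  # visits opt_idx == 0 only
        _ = epoch_metrics[opt_idx]
except KeyError as err:
    print(f"old loop fails with KeyError: {err}")

# new behaviour: visit exactly the keys that exist
for opt_idx in list(epoch_metrics):  # visits opt_idx == 1
    print(f"reducing results for opt_idx={opt_idx}")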

tests/trainer/optimization/test_multiple_optimizers.py

Lines changed: 26 additions & 19 deletions
@@ -22,23 +22,18 @@

 def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
     """
-    This tests ensures reduction works in un-balanced logging settings
+    This tests ensures reduction works in unbalanced logging settings,
+    even when a Callback also logs.
     """
     class TestModel(BoringModel):
-
-        loss_1 = []
-        loss_2 = []
+        actual = {0: [], 1: []}

         def training_step(self, batch, batch_idx, optimizer_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            if optimizer_idx == 0 and self.trainer.global_step > 10:
-                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
-                self.loss_1.append(loss.detach().clone())
-            elif optimizer_idx == 1:
-                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
-                self.loss_2.append(loss.detach().clone())
-            return {"loss": loss}
+            out = super().training_step(batch, batch_idx)
+            loss = out["loss"]
+            self.log(f"loss_{optimizer_idx}", loss, on_epoch=True)
+            self.actual[optimizer_idx].append(loss)
+            return out

         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
@@ -48,16 +43,28 @@ def configure_optimizers(self):
     model = TestModel()
     model.training_epoch_end = None

+    class TestCallback(pl.Callback):
+        def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_idx):
+            # when this is called, the EpochResultStore state has not been reset yet because we are still
+            # "INSIDE_BATCH_TRAIN_LOOP" and the LoggerConnector runs its `on_train_batch_end` after the
+            # Callback (see `TrainLoop.on_train_batch_end`). For this reason, opt_idx here is the index
+            # of the last optimizer updated (the second, index 1). This produced a KeyError as reported in #5459
+            pl_module.log("test_train_batch_end", trainer.logger_connector.cached_results._opt_idx)
+
     # Initialize a trainer
     trainer = pl.Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
+        limit_train_batches=5,
+        limit_val_batches=5,
+        callbacks=[TestCallback()],
+        weights_summary=None,
     )
-
     trainer.fit(model)

-    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
-    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
-    # test loss are properly reduced
-    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
-    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
+    for k, v in model.actual.items():
+        assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1])
+        # test loss is properly reduced
+        torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean())
+
+    assert trainer.callback_metrics["test_train_batch_end"] == len(model.optimizers()) - 1
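
One note on the test's design: the callback logs `trainer.logger_connector.cached_results._opt_idx`, and the final assertion checks that this equals `len(model.optimizers()) - 1`, i.e. the index of the last optimizer. That matches the ordering spelled out in the callback's comment: the Callback hook runs while the store is still "INSIDE_BATCH_TRAIN_LOOP", before the LoggerConnector's own `on_train_batch_end` resets that state, so the cached `_opt_idx` is still the index of the last optimizer that was updated.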
