
Commit 5fcca4e

Authored by swethmandava (Swetha Mandava), with Borda, SeanNaren, and s-rog
passing batch outputs to on_train_batch_end (#4369)
* passing batch outputs to on_train_batch_end
* styling
* updating epoch end logic
* also condition on on_train_epoch_end hooks
* more readable
* pep8
* pep8
* readability suggestion accepted
  Co-authored-by: Jirka Borovec <[email protected]>
* adding test_training_epoch_end_metrics_collection_on_override test
* fix formatting
* fix formatting

Co-authored-by: Swetha Mandava <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
1 parent eee3b1a commit 5fcca4e
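For orientation, a hedged user-level sketch (not part of the commit itself): with this change, the outputs returned by training_step for a batch are passed through to the on_train_batch_end hooks for every batch, while they are only retained for the end of the epoch when training_epoch_end or on_train_epoch_end is overridden (or a Result requests auto-reduction). The model below only illustrates the hook signatures exercised by the new test; the class name is hypothetical and the exact nesting of the outputs (per optimizer, per truncated-BPTT step) is version-specific.

import torch
import pytorch_lightning as pl


class BatchOutputsModel(pl.LightningModule):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # whatever is returned here is what the batch-end / epoch-end hooks receive
        return {"loss": self.layer(batch).sum()}

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # with this change, `outputs` carries the training_step result for this batch
        pass

    def training_epoch_end(self, outputs):
        # because this hook is overridden, the per-batch outputs are tracked during
        # the epoch and handed over here in one list
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)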

File tree: 2 files changed (+79 additions, -18 deletions)

pytorch_lightning/trainer/training_loop.py

Lines changed: 25 additions & 18 deletions
@@ -226,13 +226,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")

-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
         self.trainer.call_hook('on_batch_end')
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)

         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)

         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
@@ -244,12 +244,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)

-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not(hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
+
             epoch_output[opt_idx].append(opt_outputs)

     def get_optimizers_iterable(self):
@@ -537,17 +552,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break

-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )
-
             # hook
             # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)

             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
@@ -901,7 +913,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -916,14 +928,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])

-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)

-        return epoch_end_outputs
+        return batch_end_outputs

     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
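The new gate in track_epoch_end_reduce_metrics hinges on the is_overridden check against the user's model. As a rough sketch of the underlying idea (an assumption about the mechanism, not Lightning's actual implementation), a hook counts as overridden when the attribute on the user's class is no longer the one inherited from the base class:

def hook_is_overridden(obj, base_cls, method_name):
    # illustrative stand-in for `is_overridden`: True when obj's class redefines the hook
    subclass_attr = getattr(type(obj), method_name, None)
    base_attr = getattr(base_cls, method_name, None)
    return subclass_attr is not None and subclass_attr is not base_attr

With such a check, a batch's outputs are kept for epoch end only when training_epoch_end or on_train_epoch_end is overridden, or when the sampled Result sets should_reduce_on_epoch_end; otherwise the loop skips them, which keeps memory flat for users who never look at epoch-end outputs.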

tests/models/test_hooks.py

Lines changed: 54 additions & 0 deletions
@@ -18,8 +18,10 @@
 import pytest
 import torch

+
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
+import pytorch_lightning as pl
 from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@@ -90,6 +92,58 @@ def training_epoch_end(self, outputs):
         assert metrics[f'epoch_metric_{i}'] == i


+def test_training_epoch_end_metrics_collection_on_override(tmpdir):
+    """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
+    num_epochs = 1
+
+    class LoggingCallback(pl.Callback):
+
+        def on_train_epoch_end(self, trainer, pl_module):
+            self.len_outputs = 0
+
+        def on_train_epoch_end(self, trainer, pl_module, outputs):
+            self.len_outputs = len(outputs[0])
+
+    class OverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def training_epoch_end(self, outputs):  # Overridden
+            pass
+            return
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    class NotOverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    overridden_model = OverriddenModel()
+    not_overridden_model = NotOverriddenModel()
+
+    callback = LoggingCallback()
+    trainer = Trainer(
+        max_epochs=num_epochs,
+        default_root_dir=tmpdir,
+        overfit_batches=2,
+        callbacks=[callback],
+    )
+
+    result = trainer.fit(overridden_model)
+    assert callback.len_outputs == overridden_model.num_train_batches
+    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
+
+    result = trainer.fit(not_overridden_model)
+    assert callback.len_outputs == 0
+    # outputs from on_train_batch_end should be empty
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_transfer_batch_hook():

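Building on the test above, a hedged usage sketch (hypothetical callback name): a Callback can inspect what was gathered over the epoch through the outputs argument of on_train_epoch_end. The indexing mirrors the test's outputs[0], i.e. the per-batch results for the first optimizer; the element type itself is version-specific.

import pytorch_lightning as pl


class EpochOutputsInspector(pl.Callback):  # hypothetical name
    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # outputs[0] holds one entry per training batch for the first optimizer
        num_tracked = len(outputs[0]) if outputs else 0
        print(f"tracked outputs for {num_tracked} batches this epoch")

It would be attached the same way as the LoggingCallback in the test, e.g. Trainer(callbacks=[EpochOutputsInspector()]).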