
Commit e429f97

Authored by rohitgr7, tchaton, carmocca, mergify[bot] and Borda

Separate epoch validation from step validation (#5208)

* Separate epoch validation from step validation
* update system
* test
* baked logic in callbacks
* unbake logic in callbacks
* fix the call for scheduler
* use property
* pep
* correct rebase
* gitignore
* ref
* add tests
* fix
* add early stopping test
* trigger
* chlog
* rev
* 1.3
* log
* Apply suggestions from code review
  Co-authored-by: Carlos Mocholí <[email protected]>
* Update pytorch_lightning/trainer/training_loop.py
* Update CHANGELOG.md
* Apply suggestions from code review

Co-authored-by: chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 3b7afb9 commit e429f97

File tree

13 files changed: +194 -85 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -145,3 +145,6 @@ pytorch\ lightning
 test-reports/
 wandb
 .forked/
+
+# ctags
+tags

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -4,12 +4,14 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
-## [1.1.8] - 2021-02-06
+## [1.1.8] - 2021-02-08
 
 ### Fixed
 
+- Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208))
 - Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))
 
+
 ## [1.1.7] - 2021-02-03
 
 ### Fixed
pytorch_lightning/callbacks/early_stopping.py

Lines changed: 0 additions & 18 deletions
@@ -88,9 +88,6 @@ def __init__(
         self.stopped_epoch = 0
         self.mode = mode
         self.warned_result_obj = False
-        # Indicates, if eval results are used as basis for early stopping
-        # It is set to False initially and overwritten, if eval results have been validated
-        self.based_on_eval_results = False
 
         self.__init_monitor_mode()
 
@@ -164,21 +161,6 @@ def on_validation_end(self, trainer, pl_module):
 
         self._run_early_stopping_check(trainer, pl_module)
 
-    def on_validation_epoch_end(self, trainer, pl_module):
-        if trainer.fast_dev_run or trainer.running_sanity_check:
-            return
-
-        if self._validate_condition_metric(trainer.callback_metrics):
-            # turn off early stopping in on_train_epoch_end
-            self.based_on_eval_results = True
-
-    def on_train_epoch_end(self, trainer, pl_module, outputs):
-        # disable early stopping in train loop when there's a val loop
-        if self.based_on_eval_results:
-            return
-
-        self._run_early_stopping_check(trainer, pl_module)
-
     def _run_early_stopping_check(self, trainer, pl_module):
         """
         Checks whether the early stopping condition is met
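
Note: with the `based_on_eval_results` bookkeeping and the epoch-end hooks removed from the callback, early stopping for training-only runs is now driven by the training loop (see `check_early_stopping_callback` in `training_loop.py` below). A minimal usage sketch of the behaviour this enables — early stopping on a metric logged during training, with no validation loop — assuming the 1.1-era `Trainer`/`EarlyStopping` APIs; the model and metric names here are illustrative, not part of this PR:

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class TrainOnlyModel(LightningModule):
    # hypothetical example model, not from this diff
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # aggregate to an epoch-level metric so EarlyStopping can monitor it
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(
    max_epochs=10,
    callbacks=[EarlyStopping(monitor="train_loss", patience=3)],
)
trainer.fit(TrainOnlyModel())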

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 3 deletions
@@ -166,7 +166,7 @@ def __init__(
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
         self.period = period
-        self.last_global_step_saved = -1
+        self._last_global_step_saved = -1
         self.prefix = prefix
         self.current_score = None
         self.best_k_models = {}
@@ -231,15 +231,15 @@ def save_checkpoint(self, trainer, pl_module):
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
-            or self.last_global_step_saved == global_step  # already saved at the last step
+            or self._last_global_step_saved == global_step  # already saved at the last step
         ):
             return
 
         self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)
 
         # track epoch when ckpt was last checked
-        self.last_global_step_saved = global_step
+        self._last_global_step_saved = global_step
 
         # what can be monitored
         monitor_candidates = self._monitor_candidates(trainer)
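
Note: the attribute is renamed to a private `_last_global_step_saved`; it acts as a duplicate-save guard now that the checkpoint callback can be triggered from more than one place per epoch. An illustrative standalone sketch of that guard (not library code, names are hypothetical):

class DedupSaver:
    """Skip a save if one already happened at this exact global step."""

    def __init__(self):
        self._last_global_step_saved = -1  # mirrors ModelCheckpoint's private field

    def maybe_save(self, global_step, save_fn):
        if self._last_global_step_saved == global_step:
            return  # already saved at this step (e.g. epoch-end and train-end both fired)
        save_fn(global_step)
        self._last_global_step_saved = global_step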

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 1 deletion
@@ -400,6 +400,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
                 if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                     del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                 rank_zero_warn(
-                    'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
+                    'Warning, `hyper_parameters` dropped from checkpoint.'
+                    f' An attribute is not picklable {err}'
                 )
                 atomic_save(checkpoint, filepath)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 2 additions & 11 deletions
@@ -71,17 +71,8 @@ def get_evaluation_dataloaders(self, max_batches):
 
         return dataloaders, max_batches
 
-    def should_skip_evaluation(self, dataloaders, max_batches):
-        # skip when dataloaders aren't defined
-        if dataloaders is None:
-            return True
-
-        # enable disabling validation step with limit_val_batches = 0
-        should_skip = sum(max_batches) == 0
-        if should_skip:
-            return True
-
-        return False
+    def should_skip_evaluation(self, max_batches):
+        return sum(max_batches) == 0
 
     def on_evaluation_start(self, *args, **kwargs):
         if self.trainer.testing:
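
Note: the dataloader `None` check is dropped and the skip decision collapses to "are all batch budgets zero", which is how `limit_val_batches=0` disables validation. A standalone restatement of the simplified check with example inputs (illustrative only; `max_batches` holds one entry per validation dataloader):

def should_skip_evaluation(max_batches):
    # e.g. max_batches == [0] when limit_val_batches=0; an all-zero list means nothing to run
    return sum(max_batches) == 0

assert should_skip_evaluation([0])        # single dataloader, validation disabled
assert should_skip_evaluation([0, 0])     # multiple dataloaders, all disabled
assert not should_skip_evaluation([5, 0]) # at least one dataloader still runs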

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 5 deletions
@@ -563,9 +563,6 @@ def train(self):
                 if self.max_steps and self.max_steps <= self.global_step:
                     return
 
-                # update LR schedulers
-                self.optimizer_connector.update_learning_rates(interval='epoch')
-
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
@@ -591,7 +588,7 @@ def train(self):
             # hook
             self.train_loop.on_train_end()
 
-    def run_evaluation(self, max_batches=None):
+    def run_evaluation(self, max_batches=None, on_epoch=False):
 
         # used to know if we are logging for val, test + reset cached results
         self.logger_connector.set_stage(self.testing, reset=True)
@@ -603,7 +600,7 @@ def run_evaluation(self, max_batches=None):
         dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
 
         # check if we want to skip this evaluation
-        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
+        if self.evaluation_loop.should_skip_evaluation(max_batches):
             return [], []
 
         # ref model
@@ -664,6 +661,10 @@ def run_evaluation(self, max_batches=None):
         # hook
         self.evaluation_loop.on_evaluation_epoch_end()
 
+        # update epoch-level lr_schedulers
+        if on_epoch:
+            self.optimizer_connector.update_learning_rates(interval='epoch')
+
         # hook
         self.evaluation_loop.on_evaluation_end()
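
Note: the epoch-level LR scheduler update moves out of the outer `train()` loop and into `run_evaluation(..., on_epoch=True)`, so epoch-interval schedulers now step after the epoch-end validation run (or directly in the training loop when no validation runs, see `training_loop.py` below). A user-level sketch of the kind of scheduler configuration this affects, using the standard `configure_optimizers` API; the class name is illustrative, not from this PR:

import torch
from pytorch_lightning import LightningModule


class SchedulerExample(LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # 'interval': 'epoch' schedulers are the ones stepped by
        # optimizer_connector.update_learning_rates(interval='epoch')
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
        }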

pytorch_lightning/trainer/training_loop.py

Lines changed: 53 additions & 25 deletions
@@ -18,7 +18,7 @@
 import torch
 import torch.distributed as torch_distrib
 
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -153,7 +153,7 @@ def on_train_end(self):
         # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
         # when a checkpoint was saved at the last step
         self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_save=True, is_last=True)
+        self.check_checkpoint_callback(should_update=True, is_last=True)
         self.trainer.global_step += 1
 
         # hook
@@ -176,18 +176,27 @@ def on_train_end(self):
         model.cpu()
         torch.cuda.empty_cache()
 
-    def check_checkpoint_callback(self, should_save, is_last=False):
-        # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks
 
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
+            if is_last and any(cb.save_last for cb in callbacks):
                 rank_zero_info("Saving latest checkpoint...")
 
             model = self.trainer.get_model()
 
-            for callback in checkpoint_callbacks:
-                callback.on_validation_end(self.trainer, model)
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.get_model()
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
 
     def on_train_epoch_start(self, epoch):
 
@@ -518,7 +527,6 @@ def tbptt_split_batch(self, batch):
         return splits
 
     def run_training_epoch(self):
-
         # get model
         model = self.trainer.get_model()
 
@@ -531,7 +539,6 @@ def run_training_epoch(self):
         # enable profiling for the dataloader
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
-        should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
 
             self.trainer.batch_idx = batch_idx
@@ -580,11 +587,12 @@ def run_training_epoch(self):
             self.trainer.checkpoint_connector.has_trained = True
 
             # max steps reached, end training
-            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
-                accumulation_done = self._accumulated_batches_reached()
-                # Ensure accumulation across batches has completed before breaking loop
-                if accumulation_done:
-                    break
+            if (
+                self.trainer.max_steps is not None
+                and self.trainer.max_steps == self.trainer.global_step + 1
+                and self._accumulated_batches_reached()
+            ):
+                break
 
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
@@ -595,7 +603,7 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1
 
             # stop epoch if we limited the number of training batches
-            if (batch_idx + 1) >= self.trainer.num_training_batches:
+            if self._num_training_batches_reached(is_last_batch):
                 break
 
             # progress global step according to grads progress
@@ -612,8 +620,20 @@ def run_training_epoch(self):
             self.num_optimizers
         )
 
-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
+        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        if should_check_val:
+            self.trainer.run_evaluation(on_epoch=True)
+            # reset stage to train
+            self.trainer.logger_connector.set_stage("train")
+
+        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
+        should_train_only = self.trainer.disable_validation or should_skip_eval
+
+        if should_train_only:
+            # update epoch level lr_schedulers
+            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
 
         # increment the global step once
        # progress global step according to grads progress
@@ -853,25 +873,33 @@ def increment_accumulated_grad_global_step(self):
     def _accumulated_batches_reached(self):
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
 
-    def _num_training_batches_reached(self):
-        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+    def _num_training_batches_reached(self, is_last_batch=False):
+        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
 
     def should_accumulate(self):
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)
 
-    def should_check_val_fx(self, batch_idx, is_last_batch):
+    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        should_check_val = is_val_check_batch or self.trainer.should_stop
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
+        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+
+        should_check_val = (
+            (is_val_check_batch and epoch_end_val_check)
+            or self.trainer.should_stop
+            or is_last_batch_for_infinite_dataset
+        ) if on_epoch else (
+            is_val_check_batch
+            and not epoch_end_val_check
+        )
 
-        return should_check_val
+        return should_check_val and can_check_val
 
     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
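
Note: `should_check_val_fx` is the core of the separation. With the new `on_epoch` flag, a batch-aligned validation check that coincides with the end of the epoch fires only from the epoch-level call, while mid-epoch checks fire only from the step-level call, so the same batch can no longer trigger validation twice. A simplified standalone restatement of that decision (illustrative, with the trainer attributes flattened into plain arguments):

def should_check_val(
    is_val_check_batch,                        # (batch_idx + 1) % val_check_batch == 0
    epoch_end_val_check,                       # val_check_batch == num_training_batches
    should_stop=False,                         # trainer.should_stop
    is_last_batch_for_infinite_dataset=False,
    can_check_val=True,                        # enable_validation and is_val_check_epoch
    on_epoch=False,
):
    if on_epoch:
        check = (
            (is_val_check_batch and epoch_end_val_check)
            or should_stop
            or is_last_batch_for_infinite_dataset
        )
    else:
        check = is_val_check_batch and not epoch_end_val_check
    return check and can_check_val

# default val_check_interval (validate once per epoch): only the epoch-level call runs validation
assert should_check_val(True, True, on_epoch=True)
assert not should_check_val(True, True, on_epoch=False)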

tests/callbacks/test_callbacks.py

Lines changed: 2 additions & 2 deletions
@@ -86,15 +86,15 @@ def test_trainer_callback_system(torch_save):
         call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
         call.on_batch_end(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
+        call.on_epoch_end(trainer, model),
+        call.on_train_epoch_end(trainer, model, ANY),
         call.on_validation_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
         call.on_validation_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.on_save_checkpoint(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_train_epoch_end(trainer, model, ANY),
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),
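
Note: the expected hook order changes because the epoch-end validation run now happens after the train-epoch-end hooks in `run_training_epoch`. A hedged sketch of a callback that would observe the new order (illustrative, assuming the 1.1-era `Callback` hook signatures shown above):

from pytorch_lightning.callbacks import Callback


class HookOrderRecorder(Callback):
    """Records hook names; after this change, 'on_train_epoch_end' precedes 'on_validation_start'."""

    def __init__(self):
        self.calls = []

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        self.calls.append("on_train_epoch_end")

    def on_validation_start(self, trainer, pl_module):
        self.calls.append("on_validation_start")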

tests/callbacks/test_early_stopping.py

Lines changed: 36 additions & 3 deletions
@@ -113,11 +113,9 @@ def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
 
     class ModelOverrideValidationReturn(EvalModelTemplate):
         validation_return_values = torch.Tensor(loss_values)
-        count = 0
 
         def validation_epoch_end(self, outputs):
-            loss = self.validation_return_values[self.count]
-            self.count += 1
+            loss = self.validation_return_values[self.current_epoch]
             return {"test_val_loss": loss}
 
     model = ModelOverrideValidationReturn()
@@ -133,6 +131,41 @@ def validation_epoch_end(self, outputs):
     assert trainer.current_epoch == expected_stop_epoch
 
 
+@pytest.mark.parametrize('validation_step', ['base', None])
+@pytest.mark.parametrize(
+    "loss_values, patience, expected_stop_epoch",
+    [
+        ([6, 5, 5, 5, 5, 5], 3, 4),
+        ([6, 5, 4, 4, 3, 3], 1, 3),
+        ([6, 5, 6, 5, 5, 5], 3, 4),
+    ],
+)
+def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
+    """Test to ensure that early stopping is not triggered before patience is exhausted."""
+
+    class ModelOverrideTrainReturn(EvalModelTemplate):
+        train_return_values = torch.Tensor(loss_values)
+
+        def training_epoch_end(self, outputs):
+            loss = self.train_return_values[self.current_epoch]
+            self.log('train_loss', loss)
+
+    model = ModelOverrideTrainReturn()
+
+    if validation_step is None:
+        model.validation_step = None
+
+    early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stop_callback],
+        num_sanity_val_steps=0,
+        max_epochs=10,
+    )
+    trainer.fit(model)
+    assert trainer.current_epoch == expected_stop_epoch
+
+
 def test_pickling(tmpdir):
     early_stopping = EarlyStopping()
 
