Skip to content

Commit a26424e

Browse files
awaelchli and Borda authored
Fix zero-grad behavior when entering the validation loop (#18710)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 7fd5c02 commit a26424e

File tree

10 files changed

+61
-10
lines changed

10 files changed

+61
-10
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
148148
- Added support for returning an object of type `Mapping` from `LightningModule.training_step()` ([#18657](https://github.com/Lightning-AI/lightning/pull/18657))
149149

150150

151+
- Added the hook `LightningModule.on_validation_model_zero_grad()` to allow overriding the behavior of zeroing the gradients before entering the validation loop ([#18710](https://github.com/Lightning-AI/lightning/pull/18710))
152+
153+
151154
### Changed
152155

153156
- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
@@ -289,6 +292,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
289292
- Fixed numerical issues when reducing values in low precision with `self.log` ([#18686](https://github.com/Lightning-AI/lightning/pull/18686))
290293

291294

295+
- Fixed an issue that would cause the gradients to be erased if validation happened in the middle of a gradient accumulation phase ([#18710](https://github.com/Lightning-AI/lightning/pull/18710))
296+
292297

293298
## [2.0.9] - 2023-09-14
294299

src/lightning/pytorch/core/hooks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch import Tensor
2020
from torch.optim.optimizer import Optimizer
2121

22+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
2223
from lightning.pytorch.utilities import move_data_to_device
2324
from lightning.pytorch.utilities.exceptions import MisconfigurationException
2425
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
@@ -151,6 +152,11 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in
151152
152153
"""
153154

155+
def on_validation_model_zero_grad(self) -> None:
156+
"""Called by the training loop to release gradients before entering the validation loop."""
157+
zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
158+
self.zero_grad(**zero_grad_kwargs)
159+
154160
def on_validation_model_eval(self) -> None:
155161
"""Sets the model to eval during the val loop."""
156162
self.trainer.model.eval()

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,7 @@ def on_run_start(self) -> None:
239239
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
240240
hooks."""
241241
self._verify_dataloader_idx_requirement()
242-
243242
self._on_evaluation_model_eval()
244-
self.trainer.lightning_module.zero_grad()
245243
self._on_evaluation_start()
246244
self._on_evaluation_epoch_start()
247245

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,7 @@ def reset(self) -> None:
191191
def on_run_start(self) -> None:
192192
"""Calls ``_on_predict_model_eval``, ``_on_predict_start`` and ``_on_predict_epoch_start`` hooks."""
193193
self._verify_dataloader_idx_requirement()
194-
195-
trainer = self.trainer
196-
call._call_lightning_module_hook(trainer, "on_predict_model_eval")
197-
trainer.lightning_module.zero_grad()
194+
call._call_lightning_module_hook(self.trainer, "on_predict_model_eval")
198195
self._on_predict_start()
199196
self._on_predict_epoch_start()
200197

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
277277
self.trainer.validating = True
278278
# save and reset this state in case validation runs inside training loop (val_check_interval<1.0)
279279
first_loop_iter = self.trainer._logger_connector._first_loop_iter
280+
281+
if not self._should_accumulate():
282+
# clear gradients to not leave any unused memory during validation
283+
call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
284+
280285
self.val_loop.run()
281286
self.trainer.training = True
282287
self.trainer._logger_connector._first_loop_iter = first_loop_iter

src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class _LogOptions(TypedDict):
141141
"test_dataloader": None,
142142
"prepare_data": None,
143143
"configure_callbacks": None,
144+
"on_validation_model_zero_grad": None,
144145
"on_validation_model_eval": None,
145146
"on_test_model_eval": None,
146147
"on_validation_model_train": None,

src/lightning/pytorch/trainer/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,9 @@ def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
10221022
# wait for all to join if on distributed
10231023
self.strategy.barrier("run-stage")
10241024

1025+
zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
1026+
self.lightning_module.zero_grad(**zero_grad_kwargs)
1027+
10251028
if self.evaluating:
10261029
return self._evaluation_loop.run()
10271030
if self.predicting:

tests/tests_pytorch/loops/test_loops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,3 +851,34 @@ def _get_iterator(self):
851851
3, # teardown on epoch 2, workers from epoch 2 get destroyed
852852
]
853853
assert val_dataloader.shutdown_workers_epochs == expected
854+
855+
856+
def test_validation_during_gradient_accumulation_window(tmp_path):
857+
"""Test that gradients don't get erased when the validation interval falls within the gradient accumulation
858+
phase."""
859+
860+
class ValidationModel(BoringModel):
861+
def on_validation_start(self):
862+
batch_idx = self.trainer.fit_loop.epoch_loop.batch_progress.current.completed
863+
grad_expected = batch_idx % self.trainer.accumulate_grad_batches != 0
864+
if grad_expected:
865+
assert batch_idx in (2, 4)
866+
assert all(p.grad is not None for p in self.parameters())
867+
else:
868+
assert batch_idx == 6
869+
assert all(p.grad is None for p in self.parameters())
870+
self.ran_assert = True
871+
872+
model = ValidationModel()
873+
trainer = Trainer(
874+
default_root_dir=tmp_path,
875+
limit_train_batches=6,
876+
limit_val_batches=1,
877+
accumulate_grad_batches=3,
878+
# validation happens in the middle of the first two accumulations, and at the end of the third
879+
val_check_interval=2,
880+
max_epochs=1,
881+
num_sanity_val_steps=0,
882+
)
883+
trainer.fit(model)
884+
assert model.ran_assert

tests/tests_pytorch/models/test_hooks.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import pytest
2020
import torch
21+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
2122
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
2223
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
2324
from torch import Tensor
@@ -465,11 +466,11 @@ def training_step(self, batch, batch_idx):
465466
{"name": "configure_optimizers"},
466467
{"name": "Callback.on_fit_start", "args": (trainer, model)},
467468
{"name": "on_fit_start"},
469+
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
468470
{"name": "Callback.on_sanity_check_start", "args": (trainer, model)},
469471
{"name": "val_dataloader"},
470472
{"name": "train", "args": (False,)},
471473
{"name": "on_validation_model_eval"},
472-
{"name": "zero_grad"},
473474
{"name": "Callback.on_validation_start", "args": (trainer, model)},
474475
{"name": "on_validation_start"},
475476
*model._eval_epoch("validation", trainer, model, val_batches, "x", device=device),
@@ -486,9 +487,10 @@ def training_step(self, batch, batch_idx):
486487
{"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
487488
{"name": "on_train_epoch_start"},
488489
*model._train_batch(trainer, model, train_batches, device=device, **kwargs),
490+
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
491+
{"name": "on_validation_model_zero_grad"},
489492
{"name": "train", "args": (False,)},
490493
{"name": "on_validation_model_eval"},
491-
{"name": "zero_grad"},
492494
{"name": "Callback.on_validation_start", "args": (trainer, model)},
493495
{"name": "on_validation_start"},
494496
*model._eval_epoch("validation", trainer, model, val_batches, "x", device=device),
@@ -566,6 +568,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
566568
{"name": "configure_optimizers"},
567569
{"name": "Callback.on_fit_start", "args": (trainer, model)},
568570
{"name": "on_fit_start"},
571+
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
569572
{"name": "train_dataloader"},
570573
{"name": "train", "args": (True,)},
571574
{"name": "Callback.on_train_start", "args": (trainer, model)},
@@ -644,6 +647,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
644647
{"name": "configure_optimizers"},
645648
{"name": "Callback.on_fit_start", "args": (trainer, model)},
646649
{"name": "on_fit_start"},
650+
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
647651
{"name": "train_dataloader"},
648652
{"name": "train", "args": (True,)},
649653
{"name": "Callback.on_train_start", "args": (trainer, model)},
@@ -690,7 +694,6 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
690694
{"name": f"{dataloader}_dataloader"},
691695
{"name": "train", "args": (False,)},
692696
{"name": f"on_{noun}_model_eval"},
693-
{"name": "zero_grad"},
694697
{"name": f"Callback.on_{noun}_start", "args": (trainer, model)},
695698
{"name": f"on_{noun}_start"},
696699
*model._eval_epoch(noun, trainer, model, batches, key, trainer.strategy.root_device),
@@ -705,6 +708,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
705708
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
706709
{"name": "setup", "kwargs": {"stage": verb}},
707710
{"name": "configure_model"},
711+
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
708712
*(hooks if batches else []),
709713
{"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
710714
{"name": "teardown", "kwargs": {"stage": verb}},
@@ -727,10 +731,10 @@ def test_trainer_model_hook_system_predict(tmpdir):
727731
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
728732
{"name": "setup", "kwargs": {"stage": "predict"}},
729733
{"name": "configure_model"},
734+
{"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
730735
{"name": "predict_dataloader"},
731736
{"name": "train", "args": (False,)},
732737
{"name": "on_predict_model_eval"},
733-
{"name": "zero_grad"},
734738
{"name": "Callback.on_predict_start", "args": (trainer, model)},
735739
{"name": "on_predict_start"},
736740
{"name": "Callback.on_predict_epoch_start", "args": (trainer, model)},

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def test_fx_validator_integration(tmpdir):
211211
"on_sanity_check_end": "You can't",
212212
"prepare_data": "You can't",
213213
"configure_callbacks": "You can't",
214+
"on_validation_model_zero_grad": "You can't",
214215
"on_validation_model_eval": "You can't",
215216
"on_validation_model_train": "You can't",
216217
"lr_scheduler_step": "You can't",

0 commit comments

Comments
 (0)