Commit f2c5f5b

justusschock authored and committed (with tchaton, carmocca, awaelchli)
Clear reference to training loss at the end of train step (#9336)
Without clearing this reference, the loss tensor stays alive through the next training step. This can be a problem for memory-intensive models that produce very deep backward graphs, such as neural ODEs. For these models, keeping the backward graph of the previous loss in memory can lead to OOM errors in the next training step, even though the step might have succeeded had we cleared (and thus garbage-collected) the previous backward graph.

Co-authored-by: tchaton <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent a5ad966 commit f2c5f5b
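
Why this matters, in miniature: the sketch below is not part of the commit; `held` is a hypothetical stand-in for `trainer._results.minimize`. It shows how a single lingering Python reference keeps the previous step's autograd graph reachable while the next step builds its own.

    import torch

    # a deep stack of layers stands in for a model with a very deep backward graph
    net = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(200)])
    data = torch.randn(8, 64)

    held = None  # hypothetical stand-in for `trainer._results.minimize`
    for step in range(3):
        # while `held` still points at the previous loss, the previous backward
        # graph stays reachable, so two graphs briefly coexist in memory here
        loss = net(data).pow(2).mean()
        loss.backward()
        held = loss
        # the commit's fix, in miniature: drop the reference once the step is
        # done, so the old graph can be garbage collected before the next step
        held = None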

File tree: 2 files changed (+22, -1 lines)

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 9 additions & 1 deletion

@@ -201,10 +201,18 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
         """
 
     def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
-        self.trainer._results.extract_batch_size(split_batch)
+        assert self.trainer._results is not None
+        # when the user requests `dataloader_iter`, we can't track the batch_size
+        # and this is left to user responsibility.
+        if not isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
+            self.trainer._results.extract_batch_size(split_batch)
+
         self._batch_idx = batch_idx
         self._split_idx = split_idx
 
+        # clear reference to this step's training loss so that it can be garbage collected before the next training step
+        self.trainer._results.minimize = None
+
     def update_train_step_metrics(self) -> None:
         if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
             return

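For context on the new guard: `extract_batch_size` has to infer a batch size by inspecting the batch object, which is impossible when the "batch" handed to `training_step` is a `dataloader_iter`. A rough, hypothetical sketch of that kind of inference follows; it is not Lightning's actual implementation, just an illustration of why an iterator yields nothing to inspect.

    from typing import Any, Mapping, Optional, Sequence

    import torch

    def guess_batch_size(batch: Any) -> Optional[int]:
        # best-effort inference: find the first tensor in the batch and use its
        # first dimension; returns None for opaque objects such as iterators,
        # which is why the guard above skips `dataloader_iter` batches entirely
        if torch.is_tensor(batch):
            return batch.size(0)
        if isinstance(batch, Mapping):
            candidates = batch.values()
        elif isinstance(batch, Sequence) and not isinstance(batch, str):
            candidates = batch
        else:
            return None
        for item in candidates:
            size = guess_batch_size(item)
            if size is not None:
                return size
        return None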
tests/trainer/loops/test_training_loop.py

Lines changed: 13 additions & 0 deletions

@@ -190,3 +190,16 @@ def training_epoch_end(self, outputs) -> None:
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
     trainer.fit(model)
     assert model.on_train_batch_end_called == 2
+
+
+def test_batch_loop_releases_loss(tmpdir):
+    """Test that loss/graph is released so that it can be garbage collected before the next training step"""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert self.trainer._results.minimize is None
+            return super().training_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)

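The test exercises the clearing indirectly, by asserting inside `training_step` that the previous loss is already gone. A more direct way to check that dropping the last strong reference really releases the tensor (hypothetical, not part of the commit) is a weakref probe:

    import gc
    import weakref

    import torch

    loss = (torch.randn(4, 4, requires_grad=True) ** 2).sum()
    probe = weakref.ref(loss)

    loss = None   # drop the only strong reference, as the connector now does
    gc.collect()  # not strictly needed under CPython, but makes the point explicit
    assert probe() is None  # the loss tensor and its backward graph are gone

To run just the new test locally: pytest tests/trainer/loops/test_training_loop.py::test_batch_loop_releases_loss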