
Commit b497fb8

Remove reference to DistributedDataParallel from parallel plugin teardown (#8943)
1 parent 53885af commit b497fb8

6 files changed: 37 additions & 12 deletions
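In short, this commit deletes the `teardown` implementation from the generic `ParallelPlugin` and gives each concrete plugin its own hook: the DDP-based plugins keep the `DistributedDataParallel` un-wrapping step, while the DataParallel and Horovod plugins only need the GPU cleanup. The snippet below is a minimal, self-contained sketch of that pattern; the class names are illustrative stand-ins, not the library's actual plugin classes.

# Sketch of the pattern this commit moves to. `ParallelBase` and `DDPStylePlugin`
# are stand-ins, not pytorch_lightning classes: the generic base no longer
# references DistributedDataParallel, and only the plugin that actually wraps
# the model in DDP un-wraps it again during teardown.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class ParallelBase:  # stand-in for the generic parallel plugin
    def __init__(self, lightning_module: nn.Module, on_gpu: bool) -> None:
        self.lightning_module = lightning_module  # the plain user module
        self.model = lightning_module             # may be wrapped by a subclass
        self.on_gpu = on_gpu
    # no teardown() here any more; each subclass owns its own cleanup


class DDPStylePlugin(ParallelBase):  # stand-in for a DDP-based plugin
    def wrap(self) -> None:
        # only this subclass knows about the DDP wrapper
        self.model = DistributedDataParallel(self.lightning_module)

    def teardown(self) -> None:
        # drop the DDP wrapper so only the plain module stays referenced
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module
        if self.on_gpu:
            self.lightning_module.cpu()  # move parameters back to the CPU
            torch.cuda.empty_cache()     # release PyTorch's cached GPU memory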

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -217,6 +217,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `Plugin` in `base_plugin.py`, access `TrainingTypePlugin` and `PrecisionPlugin` directly instead ([#9066](https://github.com/PyTorchLightning/pytorch-lightning/pull/9066))


+- Removed `teardown` from `ParallelPlugin` ([#8943](https://github.com/PyTorchLightning/pytorch-lightning/pull/8943))
+
+
 ### Fixed

 - Fixed save/load/resume from checkpoint for DeepSpeed Plugin (

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 10 additions & 0 deletions
@@ -501,3 +501,13 @@ def reconciliate_processes(self, trace: str):
                 os.kill(pid, signal.SIGKILL)
         shutil.rmtree(sync_dir)
         raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+
+    def teardown(self) -> None:
+        if isinstance(self.model, DistributedDataParallel):
+            self.model = self.lightning_module
+
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 10 additions & 0 deletions
@@ -364,3 +364,13 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             description="DDPSpawn Plugin with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+
+    def teardown(self) -> None:
+        if isinstance(self.model, DistributedDataParallel):
+            self.model = self.lightning_module
+
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 7 additions & 0 deletions
@@ -119,3 +119,10 @@ def test_step_end(self, output):
         if not is_overridden("test_step_end", self.lightning_module):
             return self.reduce(output)
         return output
+
+    def teardown(self) -> None:
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 7 additions & 0 deletions
@@ -206,3 +206,10 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.Distributed
     def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tuple[str, nn.Parameter]]:
         opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
         return [(name, p) for name, p in model.named_parameters() if p in opt_params]
+
+    def teardown(self) -> None:
+        if self.on_gpu:
+            # GPU teardown
+            self.lightning_module.cpu()
+            # clean up memory
+            torch.cuda.empty_cache()
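The `teardown` hooks added to dp.py and horovod.py above omit the `DistributedDataParallel` check because those plugins never wrap the module in DDP; what all four new hooks share is only the GPU-cleanup tail. Purely as an illustration (not part of this commit), that shared tail amounts to a helper like the following, where the function name is hypothetical:

# Hypothetical helper, shown only to highlight the cpu()/empty_cache() tail
# common to the four teardown hooks added in this commit.
import torch
import torch.nn as nn


def release_gpu_memory(module: nn.Module) -> None:
    module.cpu()              # move the module's parameters back to host memory
    torch.cuda.empty_cache()  # let PyTorch release its cached CUDA allocations


if __name__ == "__main__":
    # safe on CPU-only machines: empty_cache is a no-op when CUDA is uninitialized
    release_gpu_memory(nn.Linear(4, 2))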

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 0 additions & 12 deletions
@@ -133,15 +133,3 @@ def block_backward_sync(self):
                 yield None
         else:
             yield None
-
-    def teardown(self) -> None:
-        # Un-reference the wrapper if any was used.
-        # todo (tchaton): Add support for all plugins.
-        if isinstance(self.model, DistributedDataParallel):
-            self.model = self.lightning_module
-
-        if self.on_gpu:
-            # GPU teardown
-            self.lightning_module.cpu()
-            # clean up memory
-            torch.cuda.empty_cache()
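With this hunk the base `ParallelPlugin` no longer touches `DistributedDataParallel` at teardown time, which is the removal the commit title refers to; the un-wrapping logic now lives only in the DDP-based plugins above. The standalone snippet below illustrates the effect of that relocated step, using a bare `nn.Linear` and a throwaway single-process gloo group rather than a real plugin; the address and port values are placeholders chosen only to make the example self-contained.

# Illustration only: after the relocated teardown logic runs, the DDP wrapper is
# dropped and only the plain module remains referenced.
import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # placeholder rendezvous address
os.environ.setdefault("MASTER_PORT", "29500")      # placeholder port
dist.init_process_group("gloo", rank=0, world_size=1)

module = nn.Linear(4, 2)
model = DistributedDataParallel(module)  # what a DDP-based plugin holds during training

# what the relocated teardown does: drop the wrapper, keep the plain module
if isinstance(model, DistributedDataParallel):
    model = module
assert model is module

dist.destroy_process_group()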
