Commit 1a6786d

Destroy process group in atexit handler (#19931)
1 parent b9f215d commit 1a6786d

6 files changed (+29, -6 lines)

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870), [#19872](https://github.com/Lightning-AI/pytorch-lightning/pull/19872))
 
+- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))
+
 
 ### Changed
 

src/lightning/fabric/utilities/distributed.py

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,4 @@
+import atexit
 import contextlib
 import logging
 import os
@@ -291,6 +292,10 @@ def _init_dist_connection(
     log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
     torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
 
+    if torch_distributed_backend == "nccl":
+        # PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
+        atexit.register(_destroy_dist_connection)
+
     # On rank=0 let everyone know training is starting
     rank_zero_info(
         f"{'-' * 100}\n"
@@ -300,6 +305,11 @@ def _init_dist_connection(
     )
 
 
+def _destroy_dist_connection() -> None:
+    if _distributed_is_initialized():
+        torch.distributed.destroy_process_group()
+
+
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
     return "nccl" if device.type == "cuda" else "gloo"
 
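In short, `_init_dist_connection` now registers `_destroy_dist_connection` with `atexit` whenever the NCCL backend is used, and the handler is a no-op if no process group is initialized. A minimal standalone sketch of the same pattern (not Lightning code: the helper names and the single-process gloo setup are invented for illustration, and registration is done unconditionally rather than only for NCCL):

import atexit
import os

import torch.distributed


def destroy_process_group_at_exit() -> None:
    # Safe to call even if init never happened or the group was already destroyed.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def init_distributed(backend: str = "gloo") -> None:
    # Single-process setup, purely for demonstration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group(backend, rank=0, world_size=1)
    # The commit registers the handler only for the NCCL backend (PyTorch >= 2.4
    # warns about undestroyed NCCL process groups); here it is registered always.
    atexit.register(destroy_process_group_at_exit)


if __name__ == "__main__":
    init_distributed()
    # ... training would happen here ...
    # No explicit destroy_process_group() call: the atexit handler runs at exit.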

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `ModelParallelStrategy` to support 2D parallelism ([#19878](https://github.com/Lightning-AI/pytorch-lightning/pull/19878), [#19888](https://github.com/Lightning-AI/pytorch-lightning/pull/19888))
 
+- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))
 
 
 ### Changed

tests/tests_fabric/conftest.py

Lines changed: 2 additions & 3 deletions
@@ -23,7 +23,7 @@
 import torch.distributed
 from lightning.fabric.accelerators import XLAAccelerator
 from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
-from lightning.fabric.utilities.distributed import _distributed_is_initialized
+from lightning.fabric.utilities.distributed import _destroy_dist_connection
 
 if sys.version_info >= (3, 9):
     from concurrent.futures.process import _ExecutorManagerThread
@@ -78,8 +78,7 @@ def restore_env_variables():
 def teardown_process_group():
     """Ensures that the distributed process group gets closed before the next test runs."""
     yield
-    if _distributed_is_initialized():
-        torch.distributed.destroy_process_group()
+    _destroy_dist_connection()
 
 
 @pytest.fixture(autouse=True)
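The fixture above now delegates to the same helper that the atexit handler uses; because `_destroy_dist_connection` checks `_distributed_is_initialized()` first, calling it is safe whether or not a test actually created a process group. A small hypothetical check of that idempotence, outside the test suite:

from lightning.fabric.utilities.distributed import _destroy_dist_connection

# With no process group initialized, both calls are no-ops rather than errors.
_destroy_dist_connection()
_destroy_dist_connection()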

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 12 additions & 0 deletions
@@ -11,8 +11,10 @@
 from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from lightning.fabric.utilities.distributed import (
+    _destroy_dist_connection,
     _gather_all_tensors,
     _InfiniteBarrier,
+    _init_dist_connection,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
     _sync_ddp,
@@ -217,3 +219,13 @@ def test_infinite_barrier():
     barrier.__exit__(None, None, None)
     assert barrier.barrier.call_count == 2
     dist_mock.destroy_process_group.assert_called_once()
+
+
+@mock.patch("lightning.fabric.utilities.distributed.atexit")
+@mock.patch("lightning.fabric.utilities.distributed.torch.distributed.init_process_group")
+def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
+    _init_dist_connection(LightningEnvironment(), "nccl")
+    atexit_mock.register.assert_called_once_with(_destroy_dist_connection)
+    atexit_mock.reset_mock()
+    _init_dist_connection(LightningEnvironment(), "gloo")
+    atexit_mock.register.assert_not_called()
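A note on the new test: `mock.patch` decorators are applied bottom-up, so the patched `init_process_group` arrives as the first (ignored) argument and the patched `atexit` module as the second. Patching `atexit` inside `lightning.fabric.utilities.distributed` means the module's own `atexit.register(...)` lookup hits the mock, so no real handler (or process group) is created. A self-contained sketch of the same patch-where-it-is-looked-up idea against a toy module (all names here are invented):

import atexit
from unittest import mock


def register_cleanup(callback) -> None:
    # Stand-in for _init_dist_connection: registers a callback via this module's atexit.
    atexit.register(callback)


@mock.patch(f"{__name__}.atexit")  # patch the name as this module looks it up
def check_registration(atexit_mock) -> None:
    register_cleanup(print)
    atexit_mock.register.assert_called_once_with(print)


if __name__ == "__main__":
    check_registration()
    print("registration asserted against the mock; the real atexit was untouched")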

tests/tests_pytorch/conftest.py

Lines changed: 2 additions & 3 deletions
@@ -27,7 +27,7 @@
 import torch.distributed
 from lightning.fabric.plugins.environments.lightning import find_free_network_port
 from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
-from lightning.fabric.utilities.distributed import _distributed_is_initialized
+from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch.accelerators import XLAAccelerator
 from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
@@ -123,8 +123,7 @@ def restore_signal_handlers():
 def teardown_process_group():
     """Ensures that the distributed process group gets closed before the next test runs."""
     yield
-    if _distributed_is_initialized():
-        torch.distributed.destroy_process_group()
+    _destroy_dist_connection()
 
 
 @pytest.fixture(autouse=True)
