Skip to content

Commit 33941cf

Browse files
awaelchli, justusschock, Borda
authored and committed
Fix device parser logic to avoid creating CUDA context (#14319)
* let environment disable forking
* add helper function and error messages
* tests
* changelog

Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 919ce81 commit 33941cf

File tree

6 files changed

+40
-3
lines changed

6 files changed

+40
-3
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77
## [1.7.4] - 2022-08-30
88

9+
### Added
10+
11+
- Added an environment variable `PL_DISABLE_FORK` that can be used to disable all forking in the Trainer ([#14319](https://github.com/Lightning-AI/lightning/issues/14319))
12+
913
### Fixed
1014

1115
- Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))
1216
- Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846))
1317
- Fixed restoring the trainer after using `lr_find()` so that the correct LR schedule is used for the actual training ([#14113](https://github.com/Lightning-AI/lightning/pull/14113))
1418

1519

16-
1720
## [1.7.3] - 2022-08-25
1821

1922
### Fixed

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def __init__(self, strategy: Strategy, start_method: Literal["spawn", "fork", "f
6666
f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
6767
f" {', '.join(mp.get_all_start_methods())}"
6868
)
69+
if start_method in ("fork", "forkserver") and _is_forking_disabled():
70+
raise ValueError(
71+
"Forking is disabled in this environment by `PL_DISABLE_FORK=1`. Choose a different start method."
72+
)
6973

7074
@property
7175
def is_interactive_compatible(self) -> bool:
@@ -270,3 +274,8 @@ def restore(self) -> None:
270274
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
271275
torch.backends.cudnn.benchmark = self.cudnn_benchmark
272276
_set_rng_states(self.rng_states)
277+
278+
279+
def _is_forking_disabled() -> bool:
280+
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
281+
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
TPUSpawnStrategy,
7575
)
7676
from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
77+
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
7778
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
7879
from pytorch_lightning.utilities import (
7980
_StrategyType,
@@ -637,6 +638,10 @@ def _check_strategy_and_fallback(self) -> None:
637638
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
638639
f" platform. We recommend `Trainer(strategy='ddp_spawn')` instead."
639640
)
641+
if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled():
642+
raise ValueError(
643+
"Forking is disabled in this environment by `PL_DISABLE_FORK=1`. Choose a different strategy."
644+
)
640645
if strategy_flag:
641646
self._strategy_flag = strategy_flag
642647

src/pytorch_lightning/utilities/device_parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch.cuda
1919

2020
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
21+
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
2122
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
2223
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2324
from pytorch_lightning.utilities.types import _DEVICE
@@ -340,7 +341,7 @@ def num_cuda_devices() -> int:
340341
Unlike :func:`torch.cuda.device_count`, this function will do its best not to create a CUDA context for fork
341342
support, if the platform allows it.
342343
"""
343-
if "fork" not in torch.multiprocessing.get_all_start_methods():
344+
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
344345
return torch.cuda.device_count()
345346
with multiprocessing.get_context("fork").Pool(1) as pool:
346347
return pool.apply(torch.cuda.device_count)
@@ -352,7 +353,7 @@ def is_cuda_available() -> bool:
352353
Unlike :func:`torch.cuda.is_available`, this function will do its best not to create a CUDA context for fork
353354
support, if the platform allows it.
354355
"""
355-
if "fork" not in torch.multiprocessing.get_all_start_methods():
356+
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
356357
return torch.cuda.is_available()
357358
with multiprocessing.get_context("fork").Pool(1) as pool:
358359
return pool.apply(torch.cuda.is_available)

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
1415
from unittest import mock
1516
from unittest.mock import ANY, Mock
1617

1718
import pytest
1819
import torch
1920

2021
from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
22+
from tests_pytorch.helpers.runif import RunIf
2123

2224

2325
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -26,6 +28,14 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
2628
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")
2729

2830

31+
@RunIf(skip_windows=True)
32+
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
33+
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
34+
def test_multiprocessing_launcher_disabled_forking(start_method):
35+
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
36+
_MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
37+
38+
2939
@pytest.mark.parametrize("start_method", ["spawn", "fork"])
3040
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
3141
def test_multiprocessing_launcher_start_method(mp_mock, start_method):

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,3 +808,12 @@ def test_accelerator_specific_checkpoint_io(*_):
808808
def test_ddp_fork_on_unsupported_platform(_, strategy):
809809
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
810810
Trainer(strategy=strategy)
811+
812+
813+
@RunIf(skip_windows=True)
814+
@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES)
815+
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
816+
def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy):
817+
"""Test there is an error when forking is disabled via the environment variable and the user requests fork."""
818+
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
819+
Trainer(devices=2, strategy=strategy)

0 commit comments

Comments
 (0)