Commit fcc7b00

Improve error message when "if name == main" guard is needed (#18298)

1 parent 6df4368 · commit fcc7b00

File tree: 4 files changed, +54 -3 lines changed
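The RuntimeError added in this commit asks users to wrap their entry point in an `if __name__ == "__main__"` guard whenever the 'spawn' start method is used (or to switch to `strategy="ddp"`). As a rough illustration of the pattern the new message recommends, a guarded Fabric script might look like the sketch below; the accelerator, device count, and the `ddp_spawn` strategy alias are illustrative assumptions, not part of this change:

    from lightning.fabric import Fabric

    def main():
        # All multiprocessing-sensitive work lives inside main(), not at module level.
        fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn")
        fabric.launch()
        fabric.print(f"Hello from rank {fabric.global_rank}")

    if __name__ == "__main__":
        # The guard prevents spawned workers from re-running the script body on import.
        main()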

src/lightning/fabric/strategies/launchers/multiprocessing.py
Lines changed: 25 additions & 0 deletions

@@ -15,6 +15,7 @@
 import os
 from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
+from textwrap import dedent
 from typing import Any, Callable, Dict, Literal, Optional, TYPE_CHECKING
 
 import torch
@@ -91,6 +92,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """
         if self._start_method in ("fork", "forkserver"):
             _check_bad_cuda_fork()
+        if self._start_method == "spawn":
+            _check_missing_main_guard()
 
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
@@ -216,3 +219,25 @@ def unshare(module: Module) -> Module:
         return module
 
     return apply_to_collection(data, function=unshare, dtype=Module)
+
+
+def _check_missing_main_guard() -> None:
+    """Raises an exception if the ``__name__ == "__main__"`` guard is missing."""
+    if not getattr(mp.current_process(), "_inheriting", False):
+        return
+    message = dedent(
+        """
+        Launching multiple processes with the 'spawn' start method requires that your script guards the main
+        function with an `if __name__ == \"__main__\"` clause. For example:
+
+            def main():
+                # Put your code here
+                ...
+
+            if __name__ == "__main__":
+                main()
+
+        Alternatively, you can run with `strategy="ddp"` to avoid this error.
+        """
+    )
+    raise RuntimeError(message)
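The detection relies on a private detail of the standard library's spawn bootstrap: while a spawned child re-imports the parent's main module, multiprocessing marks the child's process object with an `_inheriting` flag (the test below leans on the same detail, "pretend that main is importing itself"). If `launch()` is reached while that flag is set, the main module cannot have been guarded, so the error above is raised. A minimal sketch of exercising the helper directly, mirroring the mock used in the tests (illustrative, not part of the commit):

    from unittest import mock

    from lightning.fabric.strategies.launchers.multiprocessing import _check_missing_main_guard

    # In a normal main process the `_inheriting` flag is absent, so this returns silently.
    _check_missing_main_guard()

    # Simulate a spawn child that is re-importing an unguarded script.
    with mock.patch(
        "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
        return_value=mock.Mock(_inheriting=True),
    ):
        _check_missing_main_guard()  # raises RuntimeError pointing at the missing guard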

src/lightning/pytorch/strategies/launchers/multiprocessing.py
Lines changed: 7 additions & 1 deletion

@@ -27,7 +27,11 @@
 from torch import Tensor
 
 import lightning.pytorch as pl
-from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _disable_module_memory_sharing
+from lightning.fabric.strategies.launchers.multiprocessing import (
+    _check_bad_cuda_fork,
+    _check_missing_main_guard,
+    _disable_module_memory_sharing,
+)
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
 from lightning.fabric.utilities.types import _PATH
@@ -99,6 +103,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         """
        if self._start_method in ("fork", "forkserver"):
             _check_bad_cuda_fork()
+        if self._start_method == "spawn":
+            _check_missing_main_guard()
 
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
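On the Trainer side the same check fires for `strategy="ddp_spawn"`, which is the strategy the pytorch test at the bottom of this diff constructs. A guarded training script might look like the following sketch; `TinyModel` and the random dataset are made up for illustration and are not part of this commit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer

    class TinyModel(LightningModule):
        # Hypothetical minimal model, only here to make the sketch self-contained.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    def main():
        train_data = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)
        # accelerator/devices/strategy mirror the ddp_spawn test in this diff.
        trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=1)
        trainer.fit(TinyModel(), train_data)

    if __name__ == "__main__":
        main()

Without the `if __name__ == "__main__"` block, each spawned worker would re-import the script, reach `trainer.fit()` again during bootstrap, and now receive the RuntimeError introduced above; per the error text, switching to `strategy="ddp"` is the other way out.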

tests/tests_fabric/strategies/launchers/test_multiprocessing.py
Lines changed: 13 additions & 2 deletions

@@ -36,7 +36,8 @@ def test_forking_on_unsupported_platform(_):
 
 @pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_start_method(mp_mock, start_method):
+@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
+def test_start_method(_, mp_mock, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     launcher.launch(function=Mock())
@@ -51,7 +52,8 @@ def test_start_method(mp_mock, start_method):
 
 @pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_restore_globals(mp_mock, start_method):
+@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
+def test_restore_globals(_, mp_mock, start_method):
     """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
@@ -94,3 +96,12 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
         launcher.launch(function=Mock())
+
+
+def test_check_for_missing_main_guard():
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
+    with mock.patch(
+        "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process",
+        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
+    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
+        launcher.launch(function=Mock())

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
Lines changed: 9 additions & 0 deletions

@@ -203,3 +203,12 @@ def test_memory_sharing_disabled():
 
     trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
     trainer.fit(model)
+
+
+def test_check_for_missing_main_guard():
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
+    with mock.patch(
+        "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process",
+        return_value=Mock(_inheriting=True),  # pretend that main is importing itself
+    ), pytest.raises(RuntimeError, match="requires that your script guards the main"):
+        launcher.launch(function=Mock())
