Skip to content

Commit f2cc224

Browse files
committed
fix: Add callable check and improve test coverage for CUDA fork check
- Add callable() check before calling _is_in_bad_fork to ensure robustness
- Add test_check_for_bad_cuda_fork_with_is_in_bad_fork() to test the new detection path
- Ensure test coverage for both the new _is_in_bad_fork path and the fallback path
1 parent fc8a8ec commit f2cc224

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/lightning/fabric/strategies/launchers/multiprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _check_bad_cuda_fork() -> None:
199199
# is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
200200
# while still catching actual problematic cases where CUDA context was created before forking.
201201
_is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
202-
if _is_in_bad_fork is not None and _is_in_bad_fork():
202+
if _is_in_bad_fork is not None and callable(_is_in_bad_fork) and _is_in_bad_fork():
203203
message = (
204204
"Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
205205
"you must use the 'spawn' start method or avoid CUDA initialization in the main process."

tests/tests_fabric/strategies/launchers/test_multiprocessing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
9898
launcher.launch(function=Mock())
9999

100100

101+
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
102+
@mock.patch("torch.cuda._is_in_bad_fork", return_value=True)
103+
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
104+
def test_check_for_bad_cuda_fork_with_is_in_bad_fork(mp_mock, _, start_method):
105+
"""Test the new _is_in_bad_fork detection when available."""
106+
mp_mock.get_all_start_methods.return_value = [start_method]
107+
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
108+
with pytest.raises(RuntimeError, match="Cannot re-initialize CUDA in forked subprocess"):
109+
launcher.launch(function=Mock())
110+
111+
101112
def test_check_for_missing_main_guard():
102113
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
103114
with (

0 commit comments

Comments
 (0)