
Commit 1c20b38

Fix ddp_notebook CUDA fork check to allow passive initialization
The previous implementation used torch.cuda.is_initialized(), which returns True even when CUDA is only passively initialized (e.g., during library imports or device availability checks). This caused false positives in environments like Kaggle notebooks, where libraries may query CUDA without creating a context.

This fix uses PyTorch's internal torch.cuda._is_in_bad_fork() function, which more accurately detects an actual bad fork state (i.e., CUDA was initialized with a context and the process was then forked). The change allows passive CUDA initialization while still catching genuinely problematic cases. For older PyTorch versions that don't have _is_in_bad_fork, the check falls back to the previous behavior.

Fixes #21389
1 parent 79ffe50 commit 1c20b38

File tree

1 file changed (+23 -11 lines)

src/lightning/fabric/strategies/launchers/multiprocessing.py

Lines changed: 23 additions & 11 deletions
@@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
     Lightning users.
 
     """
-    if not torch.cuda.is_initialized():
-        return
-
-    message = (
-        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
-        " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
-        " other way? Please remove any such calls, or change the selected strategy."
-    )
-    if _IS_INTERACTIVE:
-        message += " You will have to restart the Python kernel."
-    raise RuntimeError(message)
+    # Use PyTorch's internal check for bad fork state, which is more accurate than just checking if CUDA
+    # is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
+    # while still catching actual problematic cases where CUDA context was created before forking.
+    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
+    if _is_in_bad_fork is not None and _is_in_bad_fork():
+        message = (
+            "Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
+            "you must use the 'spawn' start method or avoid CUDA initialization in the main process."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
+
+    # Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
+    if _is_in_bad_fork is None and torch.cuda.is_initialized():
+        message = (
+            "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+            " `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
+            " other way? Please remove any such calls, or change the selected strategy."
+        )
+        if _IS_INTERACTIVE:
+            message += " You will have to restart the Python kernel."
+        raise RuntimeError(message)
 
 
 def _disable_module_memory_sharing(data: Any) -> Any:
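
For readers who want to exercise the new behavior outside the launcher, the decision logic from this diff can be reproduced as a small standalone sketch. The helper name is_bad_cuda_fork below is illustrative only and not part of Lightning's or PyTorch's public API:

import torch


def is_bad_cuda_fork() -> bool:
    # Prefer PyTorch's internal flag: per this commit, it only reports True when a CUDA
    # context was created and the process was subsequently forked, so passive CUDA
    # initialization (library imports, device availability queries) does not trip it.
    _is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
    if _is_in_bad_fork is not None:
        return _is_in_bad_fork()
    # Older PyTorch versions without the internal helper: fall back to the stricter
    # is_initialized() check, mirroring the fallback branch in the diff above.
    return torch.cuda.is_initialized()


if __name__ == "__main__":
    # In a fresh (un-forked) process this prints False, even if CUDA has already been
    # passively initialized by an earlier import or availability check.
    print(is_bad_cuda_fork())

Note that this sketch collapses the two branches of the patched _check_bad_cuda_fork into a boolean so it can be inspected without raising; the actual Lightning function raises a RuntimeError with the messages shown in the diff.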
