From 4223acc8445f62d7cb33ef527ed395d1ea9fc4f8 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 27 May 2025 21:37:14 -0500 Subject: [PATCH 1/4] Set multiprocessing to spawn for Linux --- pymc/sampling/mcmc.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..fae0107f9e 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -17,6 +17,7 @@ import contextlib import logging import pickle +import platform import sys import time import warnings @@ -78,6 +79,11 @@ ) from pymc.vartypes import discrete_types +if platform.system() == "linux": + import multiprocessing + + multiprocessing.set_start_method("spawn", force=True) + try: from zarr.storage import MemoryStore except ImportError: From 59b3ad35ac1aa7670cfd4a58f318cba4d392eea5 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 27 May 2025 21:45:21 -0500 Subject: [PATCH 2/4] Added code comment --- pymc/sampling/mcmc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index fae0107f9e..c99f89003c 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -80,6 +80,7 @@ from pymc.vartypes import discrete_types if platform.system() == "linux": + # Threads are not fork-safe on Linux, so we need to use spawn import multiprocessing multiprocessing.set_start_method("spawn", force=True) From ff251340152c2d7440cb2b5a64211a2e320bf047 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 27 May 2025 22:10:36 -0500 Subject: [PATCH 3/4] Moved linux mp setting to parallel.py with the other platform settings --- pymc/sampling/mcmc.py | 7 ------- pymc/sampling/parallel.py | 22 +++++++++++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index c99f89003c..f2dfa6e9c2 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -17,7 +17,6 @@ import contextlib import logging import pickle -import platform import sys import time import warnings @@ -79,12 +78,6 @@ ) from pymc.vartypes import discrete_types -if platform.system() == "linux": - # Threads are not fork-safe on Linux, so we need to use spawn - import multiprocessing - - multiprocessing.set_start_method("spawn", force=True) - try: from zarr.storage import MemoryStore except ImportError: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index af2106ce6f..bac2103404 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -437,15 +437,23 @@ def __init__( if mp_ctx is None or isinstance(mp_ctx, str): # Closes issue https://github.com/pymc-devs/pymc/issues/3849 # Related issue https://github.com/pymc-devs/pymc/issues/5339 - if mp_ctx is None and platform.system() == "Darwin": - if platform.processor() == "arm": - mp_ctx = "fork" + if mp_ctx is None: + if platform.system() == "Darwin": + if platform.processor() == "arm": + mp_ctx = "fork" + logger.debug( + "mp_ctx is set to 'fork' for MacOS with ARM architecture. " + + "This might cause unexpected behavior with JAX, which is inherently multithreaded." + ) + else: + mp_ctx = "forkserver" + elif platform.system() == "Linux": + # Threads are not fork-safe on Linux when using libraries like OpenBLAS + mp_ctx = "spawn" logger.debug( - "mp_ctx is set to 'fork' for MacOS with ARM architecture. " - + "This might cause unexpected behavior with JAX, which is inherently multithreaded." + "mp_ctx is set to 'spawn' for Linux to ensure thread safety. " + + "This is required when using multithreaded numerical libraries." ) - else: - mp_ctx = "forkserver" mp_ctx = multiprocessing.get_context(mp_ctx) From 8b6f9e67becf768354f73abc85c704c24fecad5f Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 27 May 2025 22:13:45 -0500 Subject: [PATCH 4/4] updated code comment --- pymc/sampling/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index bac2103404..292bb4d982 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -448,7 +448,7 @@ def __init__( else: mp_ctx = "forkserver" elif platform.system() == "Linux": - # Threads are not fork-safe on Linux when using libraries like OpenBLAS + # Threads are not fork-safe on Linux mp_ctx = "spawn" logger.debug( "mp_ctx is set to 'spawn' for Linux to ensure thread safety. "