diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index af2106ce6f..292bb4d982 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -437,15 +437,23 @@ def __init__( if mp_ctx is None or isinstance(mp_ctx, str): # Closes issue https://github.com/pymc-devs/pymc/issues/3849 # Related issue https://github.com/pymc-devs/pymc/issues/5339 - if mp_ctx is None and platform.system() == "Darwin": - if platform.processor() == "arm": - mp_ctx = "fork" + if mp_ctx is None: + if platform.system() == "Darwin": + if platform.processor() == "arm": + mp_ctx = "fork" + logger.debug( + "mp_ctx is set to 'fork' for MacOS with ARM architecture. " + + "This might cause unexpected behavior with JAX, which is inherently multithreaded." + ) + else: + mp_ctx = "forkserver" + elif platform.system() == "Linux": + # Threads are not fork-safe on Linux + mp_ctx = "spawn" logger.debug( - "mp_ctx is set to 'fork' for MacOS with ARM architecture. " - + "This might cause unexpected behavior with JAX, which is inherently multithreaded." + "mp_ctx is set to 'spawn' for Linux to ensure thread safety. " + + "This is required when using multithreaded numerical libraries." ) - else: - mp_ctx = "forkserver" mp_ctx = multiprocessing.get_context(mp_ctx)