Skip to content
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,41 @@ def setup(self, worker):

def teardown(self, worker):
self._exit_stack.close()


class MultiprocessingAuthkeyPlugin(WorkerPlugin):
"""
A WorkerPlugin to propagate the main process's ``multiprocessing.current_process().authkey``
to Dask workers.

This is necessary when using a ``multiprocessing.Manager`` for communication between the
main process and its workers, especially in distributed settings such as with
``dask_jobqueue.SLURMCluster``. In standard multiprocessing, the ``authkey`` is automatically
propagated to child processes, but in distributed clusters, this must be done manually.

This plugin securely forwards the ``authkey`` from the client process to all workers by
setting the environment variable ``_DASK_MULTIPROCESSING_AUTHKEY`` and updating the worker's
``multiprocessing.current_process().authkey`` accordingly.

Examples
--------
>>> from distributed.diagnostics.plugin import MultiprocessingAuthkeyPlugin
>>> client.register_plugin(MultiprocessingAuthkeyPlugin())
"""

name = "multiprocessing-authkey"
idempotent = True

def __init__(self) -> None:
import multiprocessing.process

os.environ["_DASK_MULTIPROCESSING_AUTHKEY"] = (
multiprocessing.process.current_process().authkey.hex()
)

def setup(self, worker: Worker) -> None:
import multiprocessing.process

multiprocessing.process.current_process().authkey = bytes.fromhex(
os.environ["_DASK_MULTIPROCESSING_AUTHKEY"]
)
Loading