Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,41 @@ def setup(self, worker):

def teardown(self, worker):
self._exit_stack.close()


class MultiprocessingAuthkeyPlugin(WorkerPlugin):
"""
A WorkerPlugin to propagate the main process's ``multiprocessing.current_process().authkey``
to Dask workers.

This is necessary when using a ``multiprocessing.Manager`` for communication between the
main process and its workers, especially in distributed settings such as with
``dask_jobqueue.SLURMCluster``. In standard multiprocessing, the ``authkey`` is automatically
propagated to child processes, but in distributed clusters, this must be done manually.

This plugin securely forwards the ``authkey`` from the client process to all workers by
setting the environment variable ``_DASK_MULTIPROCESSING_AUTHKEY`` and updating the worker's
``multiprocessing.current_process().authkey`` accordingly.

Examples
--------
>>> from distributed.diagnostics.plugin import MultiprocessingAuthkeyPlugin
>>> client.register_plugin(MultiprocessingAuthkeyPlugin())
"""

name = "multiprocessing-authkey"
idempotent = True

def __init__(self) -> None:
import multiprocessing.process

os.environ["_DASK_MULTIPROCESSING_AUTHKEY"] = (
multiprocessing.process.current_process().authkey.hex()
)

def setup(self, worker: Worker) -> None:
import multiprocessing.process

multiprocessing.process.current_process().authkey = bytes.fromhex(
os.environ["_DASK_MULTIPROCESSING_AUTHKEY"]
)
20 changes: 20 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
raises_with_cause,
randominc,
relative_frame_linenumber,
report_worker_authkey,
save_sys_modules,
slowadd,
slowdec,
Expand Down Expand Up @@ -8435,3 +8436,22 @@ def reducer(futs, *, offset=0, **kwargs):

result = await future.result()
assert result == 30 if not offset else 31


async def test_authkey(loop):
# To test that the authkey is propagated to workers by dask,
# we need to start a worker manually.
# (This is because multiprocessing already propagates the authkey.)
async with Scheduler(dashboard_address=":0") as s:

# Start worker using subprocess
with popen(
[sys.executable, "-m", "distributed.cli.dask_worker", s.address],
) as proc:
async with Client(s.address, asynchronous=True) as c:
worker_authkey = await c.submit(report_worker_authkey)

import multiprocessing.process

# TODO: This currently fails!
assert worker_authkey == multiprocessing.process.current_process().authkey
6 changes: 6 additions & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2599,3 +2599,9 @@ async def padded_time(before=0.05, after=0.05):
t = time()
await asyncio.sleep(after)
return t


def report_worker_authkey():
import multiprocessing.process

return multiprocessing.process.current_process().authkey
Loading