diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 7bc2220c64..f1f7decccc 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -1102,3 +1102,41 @@ def setup(self, worker): def teardown(self, worker): self._exit_stack.close() + + +class MultiprocessingAuthkeyPlugin(WorkerPlugin): + """ + A WorkerPlugin to propagate the main process's ``multiprocessing.current_process().authkey`` + to Dask workers. + + This is necessary when using a ``multiprocessing.Manager`` for communication between the + main process and its workers, especially in distributed settings such as with + ``dask_jobqueue.SLURMCluster``. In standard multiprocessing, the ``authkey`` is automatically + propagated to child processes, but in distributed clusters, this must be done manually. + + This plugin securely forwards the ``authkey`` from the client process to all workers by + setting the environment variable ``_DASK_MULTIPROCESSING_AUTHKEY`` and updating the worker's + ``multiprocessing.current_process().authkey`` accordingly. + + Examples + -------- + >>> from distributed.diagnostics.plugin import MultiprocessingAuthkeyPlugin + >>> client.register_plugin(MultiprocessingAuthkeyPlugin()) + """ + + name = "multiprocessing-authkey" + idempotent = True + + def __init__(self) -> None: + import multiprocessing.process + + os.environ["_DASK_MULTIPROCESSING_AUTHKEY"] = ( + multiprocessing.process.current_process().authkey.hex() + ) + + def setup(self, worker: Worker) -> None: + import multiprocessing.process + + multiprocessing.process.current_process().authkey = bytes.fromhex( + os.environ["_DASK_MULTIPROCESSING_AUTHKEY"] + ) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7dce7717d7..ef765817bf 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -116,6 +116,7 @@ raises_with_cause, randominc, relative_frame_linenumber, + report_worker_authkey, save_sys_modules, slowadd, slowdec, @@ -8435,3 +8436,22 @@ def reducer(futs, *, offset=0, **kwargs): result = await future.result() assert result == 30 if not offset else 31 + + +async def test_authkey(loop): + # To test that the authkey is propagated to workers by dask, + # we need to start a worker manually. + # (This is because multiprocessing already propagates the authkey.) + async with Scheduler(dashboard_address=":0") as s: + + # Start worker using subprocess + with popen( + [sys.executable, "-m", "distributed.cli.dask_worker", s.address], + ) as proc: + async with Client(s.address, asynchronous=True) as c: + worker_authkey = await c.submit(report_worker_authkey) + + import multiprocessing.process + + # TODO: This currently fails! + assert worker_authkey == multiprocessing.process.current_process().authkey diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 8720dcd87c..bac72bf470 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2599,3 +2599,9 @@ async def padded_time(before=0.05, after=0.05): t = time() await asyncio.sleep(after) return t + + +def report_worker_authkey(): + import multiprocessing.process + + return multiprocessing.process.current_process().authkey