|
14 | 14 |
|
15 | 15 | import asyncio |
16 | 16 | import concurrent.futures as futures |
| 17 | +import contextlib |
17 | 18 | import logging.config |
18 | 19 | import multiprocessing |
19 | 20 | import os |
20 | 21 | import random |
21 | 22 | import signal |
22 | 23 | import sys |
| 24 | +import threading |
23 | 25 | import uuid |
24 | 26 | from dataclasses import dataclass |
25 | 27 | from types import TracebackType |
@@ -63,6 +65,37 @@ def _mp_kill(self): |
63 | 65 | BaseProcess.kill = _mp_kill |
64 | 66 |
|
65 | 67 | logger = logging.getLogger(__name__) |
| 68 | +_init_main_suspended_local = threading.local() |
| 69 | + |
| 70 | + |
def _patch_spawn_get_preparation_data():
    """Wrap ``multiprocessing.spawn.get_preparation_data`` once.

    The installed wrapper calls the original function and, when the
    thread-local flag ``_init_main_suspended_local.value`` is truthy,
    removes the ``init_main_from_name`` and ``init_main_from_path`` entries
    from the returned dict before handing it back.

    The wrapper is tagged with ``_mars_patched`` so later calls leave an
    already-wrapped function untouched. ``ImportError`` and
    ``AttributeError`` raised while patching are ignored.
    """
    try:
        from multiprocessing import spawn as spawn_module

        original = spawn_module.get_preparation_data
        if getattr(original, "_mars_patched", False):
            # A previous call already installed the wrapper.
            return

        def _patched_get_preparation_data(*args, **kw):
            prep_data = original(*args, **kw)
            suspended = getattr(_init_main_suspended_local, "value", False)
            if suspended:
                # Make sure the user's main module is not imported when
                # starting a Mars cluster.
                for key in ("init_main_from_name", "init_main_from_path"):
                    prep_data.pop(key, None)
            return prep_data

        _patched_get_preparation_data._mars_patched = True
        spawn_module.get_preparation_data = _patched_get_preparation_data
    except (ImportError, AttributeError):  # pragma: no cover
        pass
| 90 | + |
| 91 | + |
@contextlib.contextmanager
def _suspend_init_main():
    """Mark the current thread as suspending ``__main__`` initialisation.

    While the ``with`` body runs, ``_init_main_suspended_local.value`` is
    ``True`` for the calling thread. The patched
    ``multiprocessing.spawn.get_preparation_data`` reads this flag to decide
    whether to drop the ``init_main_from_name`` / ``init_main_from_path``
    entries.

    On exit the flag is restored to the value it had on entry rather than
    being forced to ``False``, so a nested use does not cancel an enclosing
    suspension. The flag is restored even if the body raises.
    """
    # An unset attribute on this thread counts as "not suspended".
    previous = getattr(_init_main_suspended_local, "value", False)
    _init_main_suspended_local.value = True
    try:
        yield
    finally:
        _init_main_suspended_local.value = previous
66 | 99 |
|
67 | 100 |
|
68 | 101 | @dataslots |
@@ -131,21 +164,25 @@ async def start_sub_pool( |
131 | 164 | def start_pool_in_process(): |
132 | 165 | ctx = multiprocessing.get_context(method=start_method) |
133 | 166 | status_queue = ctx.Queue() |
134 | | - process = ctx.Process( |
135 | | - target=cls._start_sub_pool, |
136 | | - args=(actor_pool_config, process_index, status_queue), |
137 | | - name=f"MarsActorPool{process_index}", |
138 | | - ) |
139 | | - process.daemon = True |
140 | | - process.start() |
| 167 | + |
| 168 | + with _suspend_init_main(): |
| 169 | + process = ctx.Process( |
| 170 | + target=cls._start_sub_pool, |
| 171 | + args=(actor_pool_config, process_index, status_queue), |
| 172 | + name=f"MarsActorPool{process_index}", |
| 173 | + ) |
| 174 | + process.daemon = True |
| 175 | + process.start() |
| 176 | + |
141 | 177 | # wait for sub actor pool to finish starting |
142 | 178 | process_status = status_queue.get() |
143 | 179 | return process, process_status |
144 | 180 |
|
| 181 | + _patch_spawn_get_preparation_data() |
145 | 182 | loop = asyncio.get_running_loop() |
146 | | - executor = futures.ThreadPoolExecutor(1) |
147 | | - create_pool_task = loop.run_in_executor(executor, start_pool_in_process) |
148 | | - return await create_pool_task |
| 183 | + with futures.ThreadPoolExecutor(1) as executor: |
| 184 | + create_pool_task = loop.run_in_executor(executor, start_pool_in_process) |
| 185 | + return await create_pool_task |
149 | 186 |
|
150 | 187 | @classmethod |
151 | 188 | async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]): |
@@ -240,8 +277,11 @@ async def kill_sub_pool( |
240 | 277 | pass |
241 | 278 | process.terminate() |
242 | 279 | wait_pool = futures.ThreadPoolExecutor(1) |
243 | | - loop = asyncio.get_running_loop() |
244 | | - await loop.run_in_executor(wait_pool, process.join, 3) |
| 280 | + try: |
| 281 | + loop = asyncio.get_running_loop() |
| 282 | + await loop.run_in_executor(wait_pool, process.join, 3) |
| 283 | + finally: |
| 284 | + wait_pool.shutdown(False) |
245 | 285 | process.kill() |
246 | 286 | await asyncio.to_thread(process.join, 5) |
247 | 287 |
|
|
0 commit comments