Skip to content

Commit bb49cc8

Browse files
committed
[monarch] Enable pickling HostMesh and ProcMesh from inside tokio thread in certain cases
Previously, both `HostMesh` and `ProcMesh` stored their underlying rust mesh in an asynchronous `Shared[...]` pytokio object, so any attempts to access the rust meshes from a tokio thread would need to be called from a coroutine. This is a problem for pickling, which needs to be synchronous, and would therefore have to call `Shared[...].block_on()` to get the underlying rust mesh. This diff makes it so that when the internal `Shared[...]` task for `HostMesh` and `ProcMesh` completes, the result is stored so that it can be accessed without the use of `block_on()`. This enables pickling `HostMesh` and `ProcMesh` inside tokio threads, as long as their backing rust meshes have finished initializing -- this will always be the case inside actor endpoints, since the `HostMesh` and `ProcMesh` both need to be done initializing by the time the actor endpoint runs. Differential Revision: [D84195494](https://our.internmc.facebook.com/intern/diff/D84195494/) ghstack-source-id: 315055002 Pull Request resolved: #1473
1 parent 9700834 commit bb49cc8

File tree

4 files changed

+81
-10
lines changed

4 files changed

+81
-10
lines changed

python/monarch/_src/actor/bootstrap.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,9 @@ def attach_to_workers(
114114
host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name)
115115
extent = Extent(["hosts"], [len(workers)])
116116
return HostMesh(
117-
host_mesh.spawn(), extent.region, stream_logs=True, is_fake_in_process=False
117+
host_mesh.spawn(),
118+
extent.region,
119+
stream_logs=True,
120+
is_fake_in_process=False,
121+
_initialized_hy_host_mesh=None,
118122
)

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,13 @@ def __init__(
9696
region: Region,
9797
stream_logs: bool,
9898
is_fake_in_process: bool,
99+
_initialized_hy_host_mesh: Optional[HyHostMesh],
99100
) -> None:
100101
self._hy_host_mesh = hy_host_mesh
101102
self._region = region
102103
self._stream_logs = stream_logs
103104
self._is_fake_in_process = is_fake_in_process
105+
self._initialized_host_mesh = _initialized_hy_host_mesh
104106

105107
@classmethod
106108
def allocate_nonblocking(
@@ -122,13 +124,24 @@ async def task() -> HyHostMesh:
122124
bootstrap_cmd,
123125
)
124126

125-
return cls(
126-
PythonTask.from_coroutine(task()).spawn(),
127+
hy_host_mesh_task = PythonTask.from_coroutine(task()).spawn()
128+
129+
hm: HostMesh = cls(
130+
hy_host_mesh_task,
127131
extent.region,
128132
alloc.stream_logs,
129133
isinstance(allocator, LocalAllocator),
134+
None,
130135
)
131136

137+
async def task(hy_host_mesh_task: Shared[HyHostMesh]) -> HyHostMesh:
138+
hm._initialized_host_mesh = await hy_host_mesh_task
139+
return hm._initialized_host_mesh
140+
141+
hm._hy_host_mesh = PythonTask.from_coroutine(task(hy_host_mesh_task)).spawn()
142+
143+
return hm
144+
132145
def spawn_procs(
133146
self,
134147
per_host: Dict[str, int] | None = None,
@@ -187,15 +200,25 @@ def _new_with_shape(self, shape: Shape) -> "HostMesh":
187200
if shape.region == self._region:
188201
return self
189202

203+
initialized_hm: Optional[HyHostMesh] = (
204+
None
205+
if self._initialized_host_mesh is None
206+
else self._initialized_host_mesh.sliced(shape.region)
207+
)
208+
190209
async def task() -> HyHostMesh:
191-
hy_host_mesh = await self._hy_host_mesh
192-
return hy_host_mesh.sliced(shape.region)
210+
return (
211+
initialized_hm
212+
if initialized_hm
213+
else (await self._hy_host_mesh).sliced(shape.region)
214+
)
193215

194216
return HostMesh(
195217
PythonTask.from_coroutine(task()).spawn(),
196218
shape.region,
197219
self.stream_logs,
198220
self.is_fake_in_process,
221+
initialized_hm,
199222
)
200223

201224
@property
@@ -222,11 +245,12 @@ async def task() -> HyHostMesh:
222245
region,
223246
stream_logs,
224247
is_fake_in_process,
248+
hy_host_mesh,
225249
)
226250

227251
def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
228252
return HostMesh._from_initialized_hy_host_mesh, (
229-
self._hy_host_mesh.block_on(),
253+
self._initialized_mesh(),
230254
self._region,
231255
self.stream_logs,
232256
self.is_fake_in_process,
@@ -238,12 +262,18 @@ def is_fake_in_process(self) -> bool:
238262

239263
def __eq__(self, other: "HostMesh") -> bool:
240264
return (
241-
self._hy_host_mesh.block_on() == other._hy_host_mesh.block_on()
265+
self._initialized_mesh() == other._initialized_mesh()
242266
and self._region == other._region
243267
and self.stream_logs == other.stream_logs
244268
and self.is_fake_in_process == other.is_fake_in_process
245269
)
246270

271+
def _initialized_mesh(self) -> HyHostMesh:
272+
if self._initialized_host_mesh is None:
273+
self._hy_host_mesh.block_on()
274+
assert self._initialized_host_mesh is not None
275+
return self._initialized_host_mesh
276+
247277

248278
def fake_in_process_host(name: str) -> "HostMesh":
249279
"""

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
host_mesh: "HostMesh",
8282
region: Region,
8383
root_region: Region,
84+
_initialized_hy_proc_mesh: Optional[HyProcMesh],
8485
_device_mesh: Optional["DeviceMesh"] = None,
8586
) -> None:
8687
_proc_mesh_registry.add(self)
@@ -91,6 +92,7 @@ def __init__(
9192
self._maybe_device_mesh = _device_mesh
9293
self._logging_manager = LoggingManager()
9394
self._controller_controller: Optional["_ControllerController"] = None
95+
self._initialized_proc_mesh = _initialized_hy_proc_mesh
9496

9597
@property
9698
def initialized(self) -> Future[Literal[True]]:
@@ -138,14 +140,25 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh":
138140
else self._maybe_device_mesh._new_with_shape(shape)
139141
)
140142

143+
initialized_pm: Optional[HyProcMesh] = (
144+
None
145+
if self._initialized_proc_mesh is None
146+
else self._initialized_proc_mesh.sliced(shape.region)
147+
)
148+
141149
async def task() -> HyProcMesh:
142-
return (await self._proc_mesh).sliced(shape.region)
150+
return (
151+
initialized_pm
152+
if initialized_pm
153+
else (await self._proc_mesh).sliced(shape.region)
154+
)
143155

144156
return ProcMesh(
145157
PythonTask.from_coroutine(task()).spawn(),
146158
self._host_mesh,
147159
shape.region,
148160
self._root_region,
161+
initialized_pm,
149162
_device_mesh=device_mesh,
150163
)
151164

@@ -188,7 +201,7 @@ def from_host_mesh(
188201
setup: Callable[[], None] | None = None,
189202
_attach_controller_controller: bool = True,
190203
) -> "ProcMesh":
191-
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region)
204+
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region, None)
192205

193206
if _attach_controller_controller:
194207
instance = context().actor_instance
@@ -211,6 +224,7 @@ async def task(
211224
stream_log_to_client: bool,
212225
) -> HyProcMesh:
213226
hy_proc_mesh = await hy_proc_mesh_task
227+
pm._initialized_proc_mesh = hy_proc_mesh
214228

215229
await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
216230

@@ -367,11 +381,14 @@ async def task() -> HyProcMesh:
367381
host_mesh,
368382
region,
369383
root_region,
384+
_initialized_hy_proc_mesh=hy_proc_mesh,
370385
)
371386

372387
def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
373388
return ProcMesh._from_initialized_hy_proc_mesh, (
374-
self._proc_mesh.block_on(),
389+
self._initialized_proc_mesh
390+
if self._initialized_proc_mesh
391+
else self._proc_mesh.block_on(),
375392
self._host_mesh,
376393
self._region,
377394
self._root_region,

python/tests/test_proc_mesh.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
from typing import cast
1010

11+
import cloudpickle
12+
1113
import pytest
14+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
1215
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice
1316
from monarch._src.actor.actor_mesh import Actor, ActorMesh, context, ValueMesh
1417
from monarch._src.actor.endpoint import endpoint
@@ -137,3 +140,20 @@ def test_nested_meshes() -> None:
137140
assert value == point.rank + 1
138141
for point, value in res_1:
139142
assert value == point.rank
143+
144+
145+
@pytest.mark.timeout(60)
146+
async def test_pickle_initialized_proc_mesh_in_tokio_thread() -> None:
147+
host = create_local_host_mesh("host", Extent(["hosts"], [2]))
148+
proc = host.spawn_procs(per_host={"gpus": 2})
149+
150+
async def task():
151+
cloudpickle.dumps(proc)
152+
153+
await proc.initialized
154+
PythonTask.from_coroutine(task()).block_on()
155+
156+
async def task():
157+
cloudpickle.dumps(proc.slice(gpus=0, hosts=0))
158+
159+
PythonTask.from_coroutine(task()).block_on()

0 commit comments

Comments
 (0)