Skip to content

Commit 1c569eb

Browse files
committed
[monarch] Enable pickling HostMesh and ProcMesh from inside tokio thread in certain cases
Pull Request resolved: #1473 Previously, both `HostMesh` and `ProcMesh` stored their underlying rust mesh in an asynchronous `Shared[...]` pytokio object, so any attempts to access the rust meshes from a tokio thread would need to be called from a coroutine. This is a problem for pickling, which needs to be synchronous, and would therefore have to call `Shared[...].block_on()` to get the underlying rust mesh. This diff makes it so that when the internal `Shared[...]` task for `HostMesh` and `ProcMesh` completes, the result is stored so that it can be accessed without the use of `block_on()`. This enables pickling `HostMesh` and `ProcMesh` inside tokio threads, as long as their backing rust meshes have finished initializing -- this will always be the case inside actor endpoints, since the `HostMesh` and `ProcMesh` both need to be done initializing by the time the actor endpoint runs. ghstack-source-id: 315174591 @exported-using-ghexport Differential Revision: [D84195494](https://our.internmc.facebook.com/intern/diff/D84195494/)
1 parent cc89acb commit 1c569eb

File tree

4 files changed

+85
-8
lines changed

4 files changed

+85
-8
lines changed

python/monarch/_src/actor/bootstrap.py

Lines changed: 5 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -114,5 +114,9 @@ def attach_to_workers(
114114
host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name)
115115
extent = Extent(["hosts"], [len(workers)])
116116
return HostMesh(
117-
host_mesh.spawn(), extent.region, stream_logs=True, is_fake_in_process=False
117+
host_mesh.spawn(),
118+
extent.region,
119+
stream_logs=True,
120+
is_fake_in_process=False,
121+
_initialized_hy_host_mesh=None,
118122
)

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 32 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -96,7 +96,17 @@ def __init__(
9696
region: Region,
9797
stream_logs: bool,
9898
is_fake_in_process: bool,
99+
_initialized_hy_host_mesh: Optional[HyHostMesh],
99100
) -> None:
101+
self._initialized_host_mesh = _initialized_hy_host_mesh
102+
if not self._initialized_host_mesh:
103+
104+
async def task(hy_host_mesh_task: Shared[HyHostMesh]) -> HyHostMesh:
105+
self._initialized_host_mesh = await hy_host_mesh_task
106+
return self._initialized_host_mesh
107+
108+
hy_host_mesh = PythonTask.from_coroutine(task(hy_host_mesh)).spawn()
109+
100110
self._hy_host_mesh = hy_host_mesh
101111
self._region = region
102112
self._stream_logs = stream_logs
@@ -127,6 +137,7 @@ async def task() -> HyHostMesh:
127137
extent.region,
128138
alloc.stream_logs,
129139
isinstance(allocator, LocalAllocator),
140+
None,
130141
)
131142

132143
def spawn_procs(
@@ -187,15 +198,25 @@ def _new_with_shape(self, shape: Shape) -> "HostMesh":
187198
if shape.region == self._region:
188199
return self
189200

201+
initialized_hm: Optional[HyHostMesh] = (
202+
None
203+
if self._initialized_host_mesh is None
204+
else self._initialized_host_mesh.sliced(shape.region)
205+
)
206+
190207
async def task() -> HyHostMesh:
191-
hy_host_mesh = await self._hy_host_mesh
192-
return hy_host_mesh.sliced(shape.region)
208+
return (
209+
initialized_hm
210+
if initialized_hm
211+
else (await self._hy_host_mesh).sliced(shape.region)
212+
)
193213

194214
return HostMesh(
195215
PythonTask.from_coroutine(task()).spawn(),
196216
shape.region,
197217
self.stream_logs,
198218
self.is_fake_in_process,
219+
initialized_hm,
199220
)
200221

201222
@property
@@ -222,11 +243,12 @@ async def task() -> HyHostMesh:
222243
region,
223244
stream_logs,
224245
is_fake_in_process,
246+
hy_host_mesh,
225247
)
226248

227249
def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
228250
return HostMesh._from_initialized_hy_host_mesh, (
229-
self._hy_host_mesh.block_on(),
251+
self._initialized_mesh(),
230252
self._region,
231253
self.stream_logs,
232254
self.is_fake_in_process,
@@ -238,12 +260,18 @@ def is_fake_in_process(self) -> bool:
238260

239261
def __eq__(self, other: "HostMesh") -> bool:
240262
return (
241-
self._hy_host_mesh.block_on() == other._hy_host_mesh.block_on()
263+
self._initialized_mesh() == other._initialized_mesh()
242264
and self._region == other._region
243265
and self.stream_logs == other.stream_logs
244266
and self.is_fake_in_process == other.is_fake_in_process
245267
)
246268

269+
def _initialized_mesh(self) -> HyHostMesh:
270+
if self._initialized_host_mesh is None:
271+
self._hy_host_mesh.block_on()
272+
assert self._initialized_host_mesh is not None
273+
return self._initialized_host_mesh
274+
247275

248276
def fake_in_process_host(name: str) -> "HostMesh":
249277
"""

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 28 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -81,9 +81,20 @@ def __init__(
8181
host_mesh: "HostMesh",
8282
region: Region,
8383
root_region: Region,
84+
_initialized_hy_proc_mesh: Optional[HyProcMesh],
8485
_device_mesh: Optional["DeviceMesh"] = None,
8586
) -> None:
8687
_proc_mesh_registry.add(self)
88+
89+
self._initialized_proc_mesh = _initialized_hy_proc_mesh
90+
if not self._initialized_proc_mesh:
91+
92+
async def task(hy_proc_mesh_task: Shared[HyProcMesh]) -> HyProcMesh:
93+
self._initialized_proc_mesh = await hy_proc_mesh_task
94+
return self._initialized_proc_mesh
95+
96+
hy_proc_mesh = PythonTask.from_coroutine(task(hy_proc_mesh)).spawn()
97+
8798
self._proc_mesh = hy_proc_mesh
8899
self._host_mesh = host_mesh
89100
self._region = region
@@ -138,14 +149,25 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh":
138149
else self._maybe_device_mesh._new_with_shape(shape)
139150
)
140151

152+
initialized_pm: Optional[HyProcMesh] = (
153+
None
154+
if self._initialized_proc_mesh is None
155+
else self._initialized_proc_mesh.sliced(shape.region)
156+
)
157+
141158
async def task() -> HyProcMesh:
142-
return (await self._proc_mesh).sliced(shape.region)
159+
return (
160+
initialized_pm
161+
if initialized_pm
162+
else (await self._proc_mesh).sliced(shape.region)
163+
)
143164

144165
return ProcMesh(
145166
PythonTask.from_coroutine(task()).spawn(),
146167
self._host_mesh,
147168
shape.region,
148169
self._root_region,
170+
initialized_pm,
149171
_device_mesh=device_mesh,
150172
)
151173

@@ -188,7 +210,7 @@ def from_host_mesh(
188210
setup: Callable[[], None] | None = None,
189211
_attach_controller_controller: bool = True,
190212
) -> "ProcMesh":
191-
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region)
213+
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region, None)
192214

193215
if _attach_controller_controller:
194216
instance = context().actor_instance
@@ -367,11 +389,14 @@ async def task() -> HyProcMesh:
367389
host_mesh,
368390
region,
369391
root_region,
392+
_initialized_hy_proc_mesh=hy_proc_mesh,
370393
)
371394

372395
def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
373396
return ProcMesh._from_initialized_hy_proc_mesh, (
374-
self._proc_mesh.block_on(),
397+
self._initialized_proc_mesh
398+
if self._initialized_proc_mesh
399+
else self._proc_mesh.block_on(),
375400
self._host_mesh,
376401
self._region,
377402
self._root_region,

python/tests/test_proc_mesh.py

Lines changed: 20 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8,7 +8,10 @@
88

99
from typing import cast
1010

11+
import cloudpickle
12+
1113
import pytest
14+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
1215
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice
1316
from monarch._src.actor.actor_mesh import Actor, ActorMesh, context, ValueMesh
1417
from monarch._src.actor.endpoint import endpoint
@@ -137,3 +140,20 @@ def test_nested_meshes() -> None:
137140
assert value == point.rank + 1
138141
for point, value in res_1:
139142
assert value == point.rank
143+
144+
145+
@pytest.mark.timeout(60)
146+
async def test_pickle_initialized_proc_mesh_in_tokio_thread() -> None:
147+
host = create_local_host_mesh("host", Extent(["hosts"], [2]))
148+
proc = host.spawn_procs(per_host={"gpus": 2})
149+
150+
async def task():
151+
cloudpickle.dumps(proc)
152+
153+
await proc.initialized
154+
PythonTask.from_coroutine(task()).block_on()
155+
156+
async def task():
157+
cloudpickle.dumps(proc.slice(gpus=0, hosts=0))
158+
159+
PythonTask.from_coroutine(task()).block_on()

0 commit comments

Comments (0)