Skip to content

Commit 6ca383a

Browse files
samluryemeta-codesync[bot]
authored andcommitted
monarch.enable_transport(), and v1 controllers (#1462)
Summary: Pull Request resolved: #1462 This diff does two things: - Introduce `monarch.enable_transport(...)`, which enables the user to set the transport of the root client. For now, only one transport is allowed. It must be called before any other monarch APIs, and it will throw if called multiple times with different transports. - Enable controllers in v1 ghstack-source-id: 314719319 Reviewed By: zdevito, mariusae Differential Revision: D84100520 fbshipit-source-id: 5f82d77c0a39048ee978ff6a6386aa03746396e4
1 parent 1090b0f commit 6ca383a

File tree

9 files changed

+336
-108
lines changed

9 files changed

+336
-108
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import inspect
1414
import itertools
1515
import logging
16+
import threading
1617
from abc import abstractproperty
1718

1819
from dataclasses import dataclass
@@ -49,6 +50,8 @@
4950
)
5051
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
5152
from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer
53+
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
54+
from monarch._rust_bindings.monarch_hyperactor.config import configure
5255
from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance
5356
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
5457
Mailbox,
@@ -230,6 +233,39 @@ def context() -> Context:
230233
return c
231234

232235

236+
_transport: Optional[ChannelTransport] = None
237+
_transport_lock = threading.Lock()
238+
239+
240+
def enable_transport(transport: ChannelTransport) -> None:
241+
"""
242+
Allow monarch to communicate with transport type 'transport'
243+
This must be called before any other calls in the monarch API.
244+
If it isn't called, we will implicitly call
245+
`monarch.enable_transport(ChannelTransport.Unix)` on the first monarch call.
246+
247+
Currently only one transport type may be enabled at one time.
248+
In the future we may allow multiple to be enabled.
249+
"""
250+
if _context.get(None) is not None:
251+
raise RuntimeError(
252+
"`enable_transport()` must be called before any other calls in the monarch API. "
253+
"If it isn't called, we will implicitly call `monarch.enable_transport(ChannelTransport.Unix)` "
254+
"on the first monarch call."
255+
)
256+
257+
global _transport
258+
with _transport_lock:
259+
if _transport is not None and _transport != transport:
260+
raise RuntimeError(
261+
f"Only one transport type may be enabled at one time. "
262+
f"Currently enabled transport type is `{_transport}`. "
263+
f"Attempted to enable transport type `{transport}`."
264+
)
265+
_transport = transport
266+
configure(default_transport=transport)
267+
268+
233269
@dataclass
234270
class DebugContext:
235271
pdb_wrapper: Optional[PdbWrapper] = None

python/monarch/_src/actor/debugger/debug_controller.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import functools
1010
from typing import Dict, List, Optional, Tuple
1111

12-
from monarch._src.actor.actor_mesh import Actor
12+
from monarch._src.actor.actor_mesh import Actor, context
1313
from monarch._src.actor.debugger.debug_command import (
1414
Attach,
1515
Cast,
@@ -33,8 +33,11 @@
3333
)
3434
from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite
3535
from monarch._src.actor.endpoint import endpoint
36-
from monarch._src.actor.proc_mesh import get_or_spawn_controller
36+
from monarch._src.actor.proc_mesh import (
37+
get_or_spawn_controller as get_or_spawn_controller_v0,
38+
)
3739
from monarch._src.actor.sync_state import fake_sync_state
40+
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh
3841
from monarch.tools.debug_env import (
3942
_get_debug_server_host,
4043
_get_debug_server_port,
@@ -243,4 +246,7 @@ async def debugger_write(
243246
@functools.cache
244247
def debug_controller() -> DebugController:
245248
with fake_sync_state():
246-
return get_or_spawn_controller("debug_controller", DebugController).get()
249+
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
250+
return get_or_spawn_controller("debug_controller", DebugController).get()
251+
else:
252+
return get_or_spawn_controller_v0("debug_controller", DebugController).get()

python/monarch/_src/actor/source_loader.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
import importlib.abc
1111
import linecache
1212

13-
from monarch._src.actor.actor_mesh import _context, Actor
13+
from monarch._src.actor.actor_mesh import _context, Actor, context
1414
from monarch._src.actor.endpoint import endpoint
15-
from monarch._src.actor.proc_mesh import get_or_spawn_controller
15+
from monarch._src.actor.proc_mesh import (
16+
get_or_spawn_controller as get_or_spawn_controller_v0,
17+
)
1618
from monarch._src.actor.sync_state import fake_sync_state
19+
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh
1720

1821

1922
class SourceLoaderController(Actor):
@@ -25,7 +28,14 @@ def get_source(self, filename: str) -> str:
2528
@functools.cache
2629
def source_loader_controller() -> SourceLoaderController:
2730
with fake_sync_state():
28-
return get_or_spawn_controller("source_loader", SourceLoaderController).get()
31+
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
32+
return get_or_spawn_controller(
33+
"source_loader", SourceLoaderController
34+
).get()
35+
else:
36+
return get_or_spawn_controller_v0(
37+
"source_loader", SourceLoaderController
38+
).get()
2939

3040

3141
@functools.cache

python/monarch/_src/actor/telemetry/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ class TracingForwarder(logging.Handler):
3030
def emit(self, record: logging.LogRecord) -> None:
3131
# Try to add actor_id from the current context to the logging record
3232
try:
33-
from monarch._src.actor.actor_mesh import context
34-
35-
ctx = context()
36-
if ctx and ctx.actor_instance and ctx.actor_instance.actor_id:
37-
# Add actor_id as an attribute to the logging record
38-
setattr(record, "actor_id", str(ctx.actor_instance.actor_id))
33+
from monarch._src.actor.actor_mesh import _context, context
34+
35+
# Don't initialize the context if it hasn't been initialized yet.
36+
if _context.get(None) is not None:
37+
ctx = context()
38+
if ctx and ctx.actor_instance and ctx.actor_instance.actor_id:
39+
# Add actor_id as an attribute to the logging record
40+
setattr(record, "actor_id", str(ctx.actor_instance.actor_id))
3941
except Exception:
4042
# If we can't get the context or actor_id for any reason, just continue
4143
# without adding the actor_id field

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def spawn_procs(
142142
name = ""
143143

144144
return self._spawn_nonblocking(
145-
name, Extent(list(per_host.keys()), list(per_host.values())), setup, False
145+
name, Extent(list(per_host.keys()), list(per_host.values())), setup, True
146146
)
147147

148148
def _spawn_nonblocking(

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from weakref import WeakSet
2929

3030
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
31-
from monarch._rust_bindings.monarch_hyperactor.shape import Region, Shape, Slice
31+
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region, Shape, Slice
3232

3333
from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import (
3434
ProcMesh as HyProcMesh,
@@ -400,7 +400,9 @@ def get_or_spawn(
400400
if name not in self._controllers:
401401
from monarch._src.actor.v1.host_mesh import this_proc
402402

403-
self._controllers[name] = this_proc().spawn(name, Class, *args, **kwargs)
403+
proc = this_proc()
404+
proc._controller_controller = _get_controller_controller()[1]
405+
self._controllers[name] = proc.spawn(name, Class, *args, **kwargs)
404406
return cast(TActor, self._controllers[name])
405407

406408

@@ -422,11 +424,16 @@ def _get_controller_controller() -> "Tuple[ProcMesh, _ControllerController]":
422424

423425
_cc_proc_mesh = fake_in_process_host(
424426
"controller_controller_host"
425-
).spawn_procs(name="controller_controller_proc")
427+
)._spawn_nonblocking(
428+
name="controller_controller_proc",
429+
per_host=Extent([], []),
430+
setup=None,
431+
_attach_controller_controller=False,
432+
)
426433
_controller_controller = _cc_proc_mesh.spawn(
427434
"controller_controller", _ControllerController
428435
)
429-
assert _cc_proc_mesh is not None
436+
assert _cc_proc_mesh is not None and _controller_controller is not None
430437
return _cc_proc_mesh, _controller_controller
431438

432439

@@ -447,7 +454,11 @@ def get_or_spawn_controller(
447454
A Future that resolves to a reference to the actor.
448455
"""
449456
cc = context().actor_instance._controller_controller
450-
if not isinstance(cc, _ControllerController):
457+
if (
458+
cc is not None
459+
and cast(ActorMesh[_ControllerController], cc)._class
460+
is not _ControllerController
461+
):
451462
# This can happen in the client process
452463
cc = _get_controller_controller()[1]
453464
return cc.get_or_spawn.call_one(name, Class, *args, **kwargs)

python/monarch/actor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Monarch Actor API - Public interface for actor functionality.
1010
"""
1111

12+
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
1213
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
1314
from monarch._src.actor.actor_mesh import (
1415
Accumulator,
@@ -20,6 +21,7 @@
2021
current_actor_name,
2122
current_rank,
2223
current_size,
24+
enable_transport,
2325
Endpoint,
2426
Point,
2527
Port,
@@ -77,4 +79,5 @@
7779
"Extent",
7880
"run_worker_loop_forever",
7981
"attach_to_workers",
82+
"enable_transport",
8083
]

0 commit comments

Comments
 (0)