Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import inspect
import itertools
import logging
import threading
from abc import abstractproperty

from dataclasses import dataclass
Expand Down Expand Up @@ -49,6 +50,8 @@
)
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure
from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
Mailbox,
Expand Down Expand Up @@ -230,6 +233,39 @@ def context() -> Context:
return c


# Transport type most recently accepted by `enable_transport()`; starts as None.
_transport: Optional[ChannelTransport] = None
# Held by `enable_transport()` while it checks and updates `_transport`.
_transport_lock = threading.Lock()


def enable_transport(transport: ChannelTransport) -> None:
    """
    Allow monarch to communicate with transport type `transport`.

    This must be called before any other calls in the monarch API.
    If it isn't called, we will implicitly call
    `monarch.enable_transport(ChannelTransport.Unix)` on the first monarch call.

    Currently only one transport type may be enabled at one time.
    In the future we may allow multiple to be enabled.

    Args:
        transport: The transport type to install as the default transport.

    Raises:
        RuntimeError: If the monarch context has already been set, or if a
            different transport type has already been enabled.
    """
    global _transport

    # Once a context exists it is too late to change the transport.
    if _context.get(None) is not None:
        raise RuntimeError(
            "`enable_transport()` must be called before any other calls in the monarch API. "
            "If it isn't called, we will implicitly call `monarch.enable_transport(ChannelTransport.Unix)` "
            "on the first monarch call."
        )

    with _transport_lock:
        # Re-enabling the same transport is allowed; switching is not.
        if _transport is not None and _transport != transport:
            raise RuntimeError(
                f"Only one transport type may be enabled at one time. "
                f"Currently enabled transport type is `{_transport}`. "
                f"Attempted to enable transport type `{transport}`."
            )
        # Apply the configuration before recording it: if `configure` raises,
        # `_transport` must not claim a transport that was never applied,
        # otherwise a retry with another transport would be wrongly rejected.
        configure(default_transport=transport)
        _transport = transport


@dataclass
class DebugContext:
pdb_wrapper: Optional[PdbWrapper] = None
Expand Down
12 changes: 9 additions & 3 deletions python/monarch/_src/actor/debugger/debug_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import functools
from typing import Dict, List, Optional, Tuple

from monarch._src.actor.actor_mesh import Actor
from monarch._src.actor.actor_mesh import Actor, context
from monarch._src.actor.debugger.debug_command import (
Attach,
Cast,
Expand All @@ -33,8 +33,11 @@
)
from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.proc_mesh import get_or_spawn_controller
from monarch._src.actor.proc_mesh import (
get_or_spawn_controller as get_or_spawn_controller_v0,
)
from monarch._src.actor.sync_state import fake_sync_state
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh
from monarch.tools.debug_env import (
_get_debug_server_host,
_get_debug_server_port,
Expand Down Expand Up @@ -243,4 +246,7 @@ async def debugger_write(
@functools.cache
def debug_controller() -> DebugController:
with fake_sync_state():
return get_or_spawn_controller("debug_controller", DebugController).get()
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
return get_or_spawn_controller("debug_controller", DebugController).get()
else:
return get_or_spawn_controller_v0("debug_controller", DebugController).get()
16 changes: 13 additions & 3 deletions python/monarch/_src/actor/source_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import importlib.abc
import linecache

from monarch._src.actor.actor_mesh import _context, Actor
from monarch._src.actor.actor_mesh import _context, Actor, context
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.proc_mesh import get_or_spawn_controller
from monarch._src.actor.proc_mesh import (
get_or_spawn_controller as get_or_spawn_controller_v0,
)
from monarch._src.actor.sync_state import fake_sync_state
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh


class SourceLoaderController(Actor):
Expand All @@ -25,7 +28,14 @@ def get_source(self, filename: str) -> str:
@functools.cache
def source_loader_controller() -> SourceLoaderController:
with fake_sync_state():
return get_or_spawn_controller("source_loader", SourceLoaderController).get()
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
return get_or_spawn_controller(
"source_loader", SourceLoaderController
).get()
else:
return get_or_spawn_controller_v0(
"source_loader", SourceLoaderController
).get()


@functools.cache
Expand Down
14 changes: 8 additions & 6 deletions python/monarch/_src/actor/telemetry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ class TracingForwarder(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Try to add actor_id from the current context to the logging record
try:
from monarch._src.actor.actor_mesh import context

ctx = context()
if ctx and ctx.actor_instance and ctx.actor_instance.actor_id:
# Add actor_id as an attribute to the logging record
setattr(record, "actor_id", str(ctx.actor_instance.actor_id))
from monarch._src.actor.actor_mesh import _context, context

# Don't initialize the context if it hasn't been initialized yet.
if _context.get(None) is not None:
ctx = context()
if ctx and ctx.actor_instance and ctx.actor_instance.actor_id:
# Add actor_id as an attribute to the logging record
setattr(record, "actor_id", str(ctx.actor_instance.actor_id))
except Exception:
# If we can't get the context or actor_id for any reason, just continue
# without adding the actor_id field
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_src/actor/v1/host_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def spawn_procs(
name = ""

return self._spawn_nonblocking(
name, Extent(list(per_host.keys()), list(per_host.values())), setup, False
name, Extent(list(per_host.keys()), list(per_host.values())), setup, True
)

def _spawn_nonblocking(
Expand Down
21 changes: 16 additions & 5 deletions python/monarch/_src/actor/v1/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from weakref import WeakSet

from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
from monarch._rust_bindings.monarch_hyperactor.shape import Region, Shape, Slice
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region, Shape, Slice

from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import (
ProcMesh as HyProcMesh,
Expand Down Expand Up @@ -400,7 +400,9 @@ def get_or_spawn(
if name not in self._controllers:
from monarch._src.actor.v1.host_mesh import this_proc

self._controllers[name] = this_proc().spawn(name, Class, *args, **kwargs)
proc = this_proc()
proc._controller_controller = _get_controller_controller()[1]
self._controllers[name] = proc.spawn(name, Class, *args, **kwargs)
return cast(TActor, self._controllers[name])


Expand All @@ -422,11 +424,16 @@ def _get_controller_controller() -> "Tuple[ProcMesh, _ControllerController]":

_cc_proc_mesh = fake_in_process_host(
"controller_controller_host"
).spawn_procs(name="controller_controller_proc")
)._spawn_nonblocking(
name="controller_controller_proc",
per_host=Extent([], []),
setup=None,
_attach_controller_controller=False,
)
_controller_controller = _cc_proc_mesh.spawn(
"controller_controller", _ControllerController
)
assert _cc_proc_mesh is not None
assert _cc_proc_mesh is not None and _controller_controller is not None
return _cc_proc_mesh, _controller_controller


Expand All @@ -447,7 +454,11 @@ def get_or_spawn_controller(
A Future that resolves to a reference to the actor.
"""
cc = context().actor_instance._controller_controller
if not isinstance(cc, _ControllerController):
if (
cc is not None
and cast(ActorMesh[_ControllerController], cc)._class
is not _ControllerController
):
# This can happen in the client process
cc = _get_controller_controller()[1]
return cc.get_or_spawn.call_one(name, Class, *args, **kwargs)
3 changes: 3 additions & 0 deletions python/monarch/actor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Monarch Actor API - Public interface for actor functionality.
"""

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
from monarch._src.actor.actor_mesh import (
Accumulator,
Expand All @@ -20,6 +21,7 @@
current_actor_name,
current_rank,
current_size,
enable_transport,
Endpoint,
Point,
Port,
Expand Down Expand Up @@ -77,4 +79,5 @@
"Extent",
"run_worker_loop_forever",
"attach_to_workers",
"enable_transport",
]
Loading