Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions docs/source/examples/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@
# -----------------------------------
# To debug an actor, simply define your python actor and insert typical breakpoints
# in the relevant endpoint that you want to debug using Python's built-in ``breakpoint()``.
#
# **Note: There is a known bug where breakpoints will not work if they are defined inside actors
# spawned on a proc mesh that was allocated from inside a different proc mesh. This will be
# resolved in the near future.**

from monarch.actor import Actor, current_rank, endpoint, this_host
from monarch._src.actor.v1.host_mesh import this_host
from monarch.actor import Actor, current_rank, endpoint


def _bad_rank():
Expand Down
32 changes: 25 additions & 7 deletions hyperactor_mesh/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::fmt;
use std::ops::Deref;
use std::panic::Location;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;

Expand Down Expand Up @@ -153,14 +154,19 @@ pub(crate) fn get_global_supervision_sink() -> Option<PortHandle<ActorSupervisio
sink_cell().read().unwrap().clone()
}

/// Context use by root client to send messages.
/// Context used by root client to send messages.
/// This mailbox allows us to open ports before we know which proc the
/// messages will be sent to.
pub fn global_root_client() -> &'static Instance<()> {
static GLOBAL_INSTANCE: OnceLock<(Instance<()>, ActorHandle<()>)> = OnceLock::new();
&GLOBAL_INSTANCE.get_or_init(|| {
///
/// Although the current client is stored in a static variable, it is
/// reinitialized every time the default transport changes.
pub fn global_root_client() -> Arc<Instance<()>> {
static GLOBAL_INSTANCE: OnceLock<
Mutex<(Arc<Instance<()>>, ActorHandle<()>, ChannelTransport)>,
> = OnceLock::new();
let init_fn = |transport: ChannelTransport| {
let client_proc = Proc::direct_with_default(
ChannelAddr::any(default_transport()),
ChannelAddr::any(transport.clone()),
"mesh_root_client_proc".into(),
router::global().clone().boxed(),
)
Expand Down Expand Up @@ -203,8 +209,20 @@ pub fn global_root_client() -> &'static Instance<()> {
},
);

(client, handle)
}).0
(Arc::new(client), handle, transport)
};

let mut instance_lock = GLOBAL_INSTANCE
.get_or_init(|| Mutex::new(init_fn(default_transport())))
.lock()
.unwrap();
let new_transport = default_transport();
let old_transport = instance_lock.2.clone();
if old_transport != new_transport {
*instance_lock = init_fn(new_transport);
}

instance_lock.0.clone()
}

type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
Expand Down
6 changes: 5 additions & 1 deletion monarch_hyperactor/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ impl<I: Into<ContextInstance>> From<I> for PyInstance {
pub(crate) struct PyContext {
instance: Py<PyInstance>,
rank: Point,
#[pyo3(get)]
is_root_client: bool,
}

#[pymethods]
Expand All @@ -162,10 +164,11 @@ impl PyContext {
#[staticmethod]
fn _root_client_context(py: Python<'_>) -> PyResult<PyContext> {
let _guard = runtime::get_tokio_runtime().enter();
let instance: PyInstance = global_root_client().into();
let instance: PyInstance = global_root_client().as_ref().into();
Ok(PyContext {
instance: instance.into_pyobject(py)?.into(),
rank: Extent::unity().point_of_rank(0).unwrap(),
is_root_client: true,
})
}
}
Expand All @@ -178,6 +181,7 @@ impl PyContext {
PyContext {
instance,
rank: cx.cast_point(),
is_root_client: false,
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions python/monarch/_rust_bindings/monarch_hyperactor/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Type hints for the monarch_hyperactor.config Rust bindings.
"""

from typing import Any, Dict
from typing import Any, Dict, Optional

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport

Expand All @@ -24,6 +24,11 @@ def reload_config_from_env() -> None:
...

def configure(
default_transport: ChannelTransport = ChannelTransport.Unix,
) -> None: ...
default_transport: Optional[ChannelTransport] = None,
) -> None:
"""
Configure typed key-value pairs in the hyperactor global configuration.
"""
...

def get_configuration() -> Dict[str, Any]: ...
29 changes: 28 additions & 1 deletion python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@
UndeliverableMessageEnvelope,
)
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
from monarch._rust_bindings.monarch_hyperactor.pytokio import (
is_tokio_thread,
PythonTask,
Shared,
)
from monarch._rust_bindings.monarch_hyperactor.selection import (
Selection as HySelection, # noqa: F401
)
Expand Down Expand Up @@ -206,6 +210,13 @@ def message_rank(self) -> Point:
@staticmethod
def _root_client_context() -> "Context": ...

@property
def is_root_client(self) -> bool:
"""
Whether this is the root client context.
"""
...


_context: contextvars.ContextVar[Context] = contextvars.ContextVar(
"monarch.actor_mesh._context"
Expand All @@ -227,6 +238,22 @@ def context() -> Context:
)

c.actor_instance.proc_mesh._host_mesh = create_local_host_mesh() # type: ignore
# If we are in the root client, and the default transport has changed, then the
# root client context needs to be updated. However, if this is called from a
# pytokio PythonTask, it isn't safe to update the root client context and we need
# to return the original context.
elif c.is_root_client and not is_tokio_thread():
root_client = Context._root_client_context()
if c.actor_instance.actor_id != root_client.actor_instance.actor_id:
c = root_client
_context.set(c)

# This path is only relevant to the v1 APIs
from monarch._src.actor.v1.proc_mesh import _get_controller_controller

c.actor_instance.proc_mesh, c.actor_instance._controller_controller = (
_get_controller_controller(force_respawn=True)
)
return c


Expand Down
12 changes: 9 additions & 3 deletions python/monarch/_src/actor/debugger/debug_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import functools
from typing import Dict, List, Optional, Tuple

from monarch._src.actor.actor_mesh import Actor
from monarch._src.actor.actor_mesh import Actor, context
from monarch._src.actor.debugger.debug_command import (
Attach,
Cast,
Expand All @@ -33,8 +33,11 @@
)
from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.proc_mesh import get_or_spawn_controller
from monarch._src.actor.proc_mesh import (
get_or_spawn_controller as get_or_spawn_controller_v0,
)
from monarch._src.actor.sync_state import fake_sync_state
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh
from monarch.tools.debug_env import (
_get_debug_server_host,
_get_debug_server_port,
Expand Down Expand Up @@ -243,4 +246,7 @@ async def debugger_write(
@functools.cache
def debug_controller() -> DebugController:
with fake_sync_state():
return get_or_spawn_controller("debug_controller", DebugController).get()
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
return get_or_spawn_controller("debug_controller", DebugController).get()
else:
return get_or_spawn_controller_v0("debug_controller", DebugController).get()
16 changes: 13 additions & 3 deletions python/monarch/_src/actor/source_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import importlib.abc
import linecache

from monarch._src.actor.actor_mesh import _context, Actor
from monarch._src.actor.actor_mesh import _context, Actor, context
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.proc_mesh import get_or_spawn_controller
from monarch._src.actor.proc_mesh import (
get_or_spawn_controller as get_or_spawn_controller_v0,
)
from monarch._src.actor.sync_state import fake_sync_state
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh


class SourceLoaderController(Actor):
Expand All @@ -25,7 +28,14 @@ def get_source(self, filename: str) -> str:
@functools.cache
def source_loader_controller() -> SourceLoaderController:
with fake_sync_state():
return get_or_spawn_controller("source_loader", SourceLoaderController).get()
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
return get_or_spawn_controller(
"source_loader", SourceLoaderController
).get()
else:
return get_or_spawn_controller_v0(
"source_loader", SourceLoaderController
).get()


@functools.cache
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_src/actor/v1/host_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def spawn_procs(
name = ""

return self._spawn_nonblocking(
name, Extent(list(per_host.keys()), list(per_host.values())), setup, False
name, Extent(list(per_host.keys()), list(per_host.values())), setup, True
)

def _spawn_nonblocking(
Expand Down
29 changes: 22 additions & 7 deletions python/monarch/_src/actor/v1/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from weakref import WeakSet

from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
from monarch._rust_bindings.monarch_hyperactor.shape import Region, Shape, Slice
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region, Shape, Slice

from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import (
ProcMesh as HyProcMesh,
Expand Down Expand Up @@ -372,7 +372,9 @@ def get_or_spawn(
if name not in self._controllers:
from monarch._src.actor.v1.host_mesh import this_proc

self._controllers[name] = this_proc().spawn(name, Class, *args, **kwargs)
proc = this_proc()
proc._controller_controller = _get_controller_controller()[1]
self._controllers[name] = proc.spawn(name, Class, *args, **kwargs)
return cast(TActor, self._controllers[name])


Expand All @@ -386,19 +388,28 @@ def get_or_spawn(
# otherwise two initializing procs will both try to init resulting in duplicates. The critical
# region is not blocking: it spawns a separate task to do the init, assigns the
# Shared[_ControllerController] from that task to the global and releases the lock.
def _get_controller_controller() -> "Tuple[ProcMesh, _ControllerController]":
def _get_controller_controller(
force_respawn: bool = False,
) -> "Tuple[ProcMesh, _ControllerController]":
global _controller_controller, _cc_proc_mesh
with _cc_init:
if _controller_controller is None:
if context().is_root_client and (
_controller_controller is None or force_respawn
):
from monarch._src.actor.v1.host_mesh import fake_in_process_host

_cc_proc_mesh = fake_in_process_host(
"controller_controller_host"
).spawn_procs(name="controller_controller_proc")
)._spawn_nonblocking(
name="controller_controller_proc",
per_host=Extent([], []),
setup=None,
_attach_controller_controller=False,
)
_controller_controller = _cc_proc_mesh.spawn(
"controller_controller", _ControllerController
)
assert _cc_proc_mesh is not None
assert _cc_proc_mesh is not None and _controller_controller is not None
return _cc_proc_mesh, _controller_controller


Expand All @@ -419,7 +430,11 @@ def get_or_spawn_controller(
A Future that resolves to a reference to the actor.
"""
cc = context().actor_instance._controller_controller
if not isinstance(cc, _ControllerController):
if (
cc is not None
and cast(ActorMesh[_ControllerController], cc)._class
is not _ControllerController
):
# This can happen in the client process
cc = _get_controller_controller()[1]
return cc.get_or_spawn.call_one(name, Class, *args, **kwargs)
8 changes: 8 additions & 0 deletions python/monarch/actor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
Monarch Actor API - Public interface for actor functionality.
"""

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
from monarch._src.actor.actor_mesh import (
Accumulator,
Expand Down Expand Up @@ -45,6 +47,9 @@
ProcMesh,
sim_proc_mesh,
)
from monarch._src.actor.v1.proc_mesh import (
get_or_spawn_controller as get_or_spawn_controller_v1,
)

__all__ = [
"Accumulator",
Expand Down Expand Up @@ -77,4 +82,7 @@
"Extent",
"run_worker_loop_forever",
"attach_to_workers",
"get_or_spawn_controller_v1",
"configure",
"ChannelTransport",
]
Loading
Loading