Skip to content

Commit 1cacbce

Browse files
committed
[monarch] monarch.actor.configure, configure default transport anywhere, wire up v1 controllers
This diff does a few things: - Introduce the `monarch.actor.configure` public API, which just wraps `monarch._rust_bindings.monarch_hyperactor.config.configure` function - Update the behavior of `global_root_client()` so that the client proc/instance are re-initialized whenever the default transport changes. - Supporting this also required updating `monarch.actor.context()` to allow overriding the current value of the `_context` contextvar if it is called from the root client and the default transport changed. It also requires forcibly respawning the `_controller_controller` when this happens. - Make sure we pass `_attach_controller_controller=True` when a spawning a proc mesh from v1 `HostMesh` - Fix the v1 implementation of `get_or_spawn_controller` Differential Revision: [D84015780](https://our.internmc.facebook.com/intern/diff/D84015780/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84015780/)! ghstack-source-id: 314421176 Pull Request resolved: #1446
1 parent bf58a9b commit 1cacbce

File tree

12 files changed

+410
-121
lines changed

12 files changed

+410
-121
lines changed

docs/source/examples/debugging.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,9 @@
2626
# -----------------------------------
2727
# To debug an actor, simply define your python actor and insert typical breakpoints
2828
# in the relevant endpoint that you want to debug using Python's built-in ``breakpoint()``.
29-
#
30-
# **Note: There is a known bug where breakpoints will not work if they are defined inside actors
31-
# spawned on a proc mesh that was allocated from inside a different proc mesh. This will be
32-
# resolved in the near future.**
3329

34-
from monarch.actor import Actor, current_rank, endpoint, this_host
30+
from monarch._src.actor.v1.host_mesh import this_host
31+
from monarch.actor import Actor, current_rank, endpoint
3532

3633

3734
def _bad_rank():

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::fmt;
1111
use std::ops::Deref;
1212
use std::panic::Location;
1313
use std::sync::Arc;
14+
use std::sync::Mutex;
1415
use std::sync::atomic::AtomicUsize;
1516
use std::sync::atomic::Ordering;
1617

@@ -153,14 +154,19 @@ pub(crate) fn get_global_supervision_sink() -> Option<PortHandle<ActorSupervisio
153154
sink_cell().read().unwrap().clone()
154155
}
155156

156-
/// Context use by root client to send messages.
157+
/// Context used by root client to send messages.
157158
/// This mailbox allows us to open ports before we know which proc the
158159
/// messages will be sent to.
159-
pub fn global_root_client() -> &'static Instance<()> {
160-
static GLOBAL_INSTANCE: OnceLock<(Instance<()>, ActorHandle<()>)> = OnceLock::new();
161-
&GLOBAL_INSTANCE.get_or_init(|| {
160+
///
161+
/// Although the current client is stored in a static variable, it is
162+
/// reinitialized every time the default transport changes.
163+
pub fn global_root_client() -> Arc<Instance<()>> {
164+
static GLOBAL_INSTANCE: OnceLock<
165+
Mutex<(Arc<Instance<()>>, ActorHandle<()>, ChannelTransport)>,
166+
> = OnceLock::new();
167+
let init_fn = |transport: ChannelTransport| {
162168
let client_proc = Proc::direct_with_default(
163-
ChannelAddr::any(default_transport()),
169+
ChannelAddr::any(transport.clone()),
164170
"mesh_root_client_proc".into(),
165171
router::global().clone().boxed(),
166172
)
@@ -203,8 +209,20 @@ pub fn global_root_client() -> &'static Instance<()> {
203209
},
204210
);
205211

206-
(client, handle)
207-
}).0
212+
(Arc::new(client), handle, transport)
213+
};
214+
215+
let mut instance_lock = GLOBAL_INSTANCE
216+
.get_or_init(|| Mutex::new(init_fn(default_transport())))
217+
.lock()
218+
.unwrap();
219+
let new_transport = default_transport();
220+
let old_transport = instance_lock.2.clone();
221+
if old_transport != new_transport {
222+
*instance_lock = init_fn(new_transport);
223+
}
224+
225+
instance_lock.0.clone()
208226
}
209227

210228
type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;

monarch_hyperactor/src/context.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ impl<I: Into<ContextInstance>> From<I> for PyInstance {
145145
pub(crate) struct PyContext {
146146
instance: Py<PyInstance>,
147147
rank: Point,
148+
#[pyo3(get)]
149+
is_root_client: bool,
148150
}
149151

150152
#[pymethods]
@@ -162,10 +164,11 @@ impl PyContext {
162164
#[staticmethod]
163165
fn _root_client_context(py: Python<'_>) -> PyResult<PyContext> {
164166
let _guard = runtime::get_tokio_runtime().enter();
165-
let instance: PyInstance = global_root_client().into();
167+
let instance: PyInstance = global_root_client().as_ref().into();
166168
Ok(PyContext {
167169
instance: instance.into_pyobject(py)?.into(),
168170
rank: Extent::unity().point_of_rank(0).unwrap(),
171+
is_root_client: true,
169172
})
170173
}
171174
}
@@ -178,6 +181,7 @@ impl PyContext {
178181
PyContext {
179182
instance,
180183
rank: cx.cast_point(),
184+
is_root_client: false,
181185
}
182186
}
183187
}

python/monarch/_rust_bindings/monarch_hyperactor/config.pyi

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
Type hints for the monarch_hyperactor.config Rust bindings.
1111
"""
1212

13-
from typing import Any, Dict
13+
from typing import Any, Dict, Optional
1414

1515
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
1616

@@ -24,6 +24,11 @@ def reload_config_from_env() -> None:
2424
...
2525

2626
def configure(
27-
default_transport: ChannelTransport = ChannelTransport.Unix,
28-
) -> None: ...
27+
default_transport: Optional[ChannelTransport] = None,
28+
) -> None:
29+
"""
30+
Configure typed key-value pairs in the hyperactor global configuration.
31+
"""
32+
...
33+
2934
def get_configuration() -> Dict[str, Any]: ...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@
5757
UndeliverableMessageEnvelope,
5858
)
5959
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
60-
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
60+
from monarch._rust_bindings.monarch_hyperactor.pytokio import (
61+
is_tokio_thread,
62+
PythonTask,
63+
Shared,
64+
)
6165
from monarch._rust_bindings.monarch_hyperactor.selection import (
6266
Selection as HySelection, # noqa: F401
6367
)
@@ -206,6 +210,13 @@ def message_rank(self) -> Point:
206210
@staticmethod
207211
def _root_client_context() -> "Context": ...
208212

213+
@property
214+
def is_root_client(self) -> bool:
215+
"""
216+
Whether this is the root client context.
217+
"""
218+
...
219+
209220

210221
_context: contextvars.ContextVar[Context] = contextvars.ContextVar(
211222
"monarch.actor_mesh._context"
@@ -227,6 +238,22 @@ def context() -> Context:
227238
)
228239

229240
c.actor_instance.proc_mesh._host_mesh = create_local_host_mesh() # type: ignore
241+
# If we are in the root client, and the default transport has changed, then the
242+
# root client context needs to be updated. However, if this is called from a
243+
# pytokio PythonTask, it isn't safe to update the root client context and we need
244+
# to return the original context.
245+
elif c.is_root_client and not is_tokio_thread():
246+
root_client = Context._root_client_context()
247+
if c.actor_instance.actor_id != root_client.actor_instance.actor_id:
248+
c = root_client
249+
_context.set(c)
250+
251+
# This path is only relevant to the v1 APIs
252+
from monarch._src.actor.v1.proc_mesh import _get_controller_controller
253+
254+
c.actor_instance.proc_mesh, c.actor_instance._controller_controller = (
255+
_get_controller_controller(force_respawn=True)
256+
)
230257
return c
231258

232259

python/monarch/_src/actor/debugger/debug_controller.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import functools
1010
from typing import Dict, List, Optional, Tuple
1111

12-
from monarch._src.actor.actor_mesh import Actor
12+
from monarch._src.actor.actor_mesh import Actor, context
1313
from monarch._src.actor.debugger.debug_command import (
1414
Attach,
1515
Cast,
@@ -33,8 +33,11 @@
3333
)
3434
from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite
3535
from monarch._src.actor.endpoint import endpoint
36-
from monarch._src.actor.proc_mesh import get_or_spawn_controller
36+
from monarch._src.actor.proc_mesh import (
37+
get_or_spawn_controller as get_or_spawn_controller_v0,
38+
)
3739
from monarch._src.actor.sync_state import fake_sync_state
40+
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh
3841
from monarch.tools.debug_env import (
3942
_get_debug_server_host,
4043
_get_debug_server_port,
@@ -243,4 +246,7 @@ async def debugger_write(
243246
@functools.cache
244247
def debug_controller() -> DebugController:
245248
with fake_sync_state():
246-
return get_or_spawn_controller("debug_controller", DebugController).get()
249+
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
250+
return get_or_spawn_controller("debug_controller", DebugController).get()
251+
else:
252+
return get_or_spawn_controller_v0("debug_controller", DebugController).get()

python/monarch/_src/actor/source_loader.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
import importlib.abc
1111
import linecache
1212

13-
from monarch._src.actor.actor_mesh import _context, Actor
13+
from monarch._src.actor.actor_mesh import _context, Actor, context
1414
from monarch._src.actor.endpoint import endpoint
15-
from monarch._src.actor.proc_mesh import get_or_spawn_controller
15+
from monarch._src.actor.proc_mesh import (
16+
get_or_spawn_controller as get_or_spawn_controller_v0,
17+
)
1618
from monarch._src.actor.sync_state import fake_sync_state
19+
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh
1720

1821

1922
class SourceLoaderController(Actor):
@@ -25,7 +28,14 @@ def get_source(self, filename: str) -> str:
2528
@functools.cache
2629
def source_loader_controller() -> SourceLoaderController:
2730
with fake_sync_state():
28-
return get_or_spawn_controller("source_loader", SourceLoaderController).get()
31+
if isinstance(context().actor_instance.proc_mesh, ProcMesh):
32+
return get_or_spawn_controller(
33+
"source_loader", SourceLoaderController
34+
).get()
35+
else:
36+
return get_or_spawn_controller_v0(
37+
"source_loader", SourceLoaderController
38+
).get()
2939

3040

3141
@functools.cache

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def spawn_procs(
147147
name = ""
148148

149149
return self._spawn_nonblocking(
150-
name, Extent(list(per_host.keys()), list(per_host.values())), setup, False
150+
name, Extent(list(per_host.keys()), list(per_host.values())), setup, True
151151
)
152152

153153
def _spawn_nonblocking(

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from weakref import WeakSet
2929

3030
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
31-
from monarch._rust_bindings.monarch_hyperactor.shape import Region, Shape, Slice
31+
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region, Shape, Slice
3232

3333
from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import (
3434
ProcMesh as HyProcMesh,
@@ -372,7 +372,9 @@ def get_or_spawn(
372372
if name not in self._controllers:
373373
from monarch._src.actor.v1.host_mesh import this_proc
374374

375-
self._controllers[name] = this_proc().spawn(name, Class, *args, **kwargs)
375+
proc = this_proc()
376+
proc._controller_controller = _get_controller_controller()[1]
377+
self._controllers[name] = proc.spawn(name, Class, *args, **kwargs)
376378
return cast(TActor, self._controllers[name])
377379

378380

@@ -386,19 +388,28 @@ def get_or_spawn(
386388
# otherwise two initializing procs will both try to init resulting in duplicates. The critical
387389
# region is not blocking: it spawns a separate task to do the init, assigns the
388390
# Shared[_ControllerController] from that task to the global and releases the lock.
389-
def _get_controller_controller() -> "Tuple[ProcMesh, _ControllerController]":
391+
def _get_controller_controller(
392+
force_respawn: bool = False,
393+
) -> "Tuple[ProcMesh, _ControllerController]":
390394
global _controller_controller, _cc_proc_mesh
391395
with _cc_init:
392-
if _controller_controller is None:
396+
if context().is_root_client and (
397+
_controller_controller is None or force_respawn
398+
):
393399
from monarch._src.actor.v1.host_mesh import fake_in_process_host
394400

395401
_cc_proc_mesh = fake_in_process_host(
396402
"controller_controller_host"
397-
).spawn_procs(name="controller_controller_proc")
403+
)._spawn_nonblocking(
404+
name="controller_controller_proc",
405+
per_host=Extent([], []),
406+
setup=None,
407+
_attach_controller_controller=False,
408+
)
398409
_controller_controller = _cc_proc_mesh.spawn(
399410
"controller_controller", _ControllerController
400411
)
401-
assert _cc_proc_mesh is not None
412+
assert _cc_proc_mesh is not None and _controller_controller is not None
402413
return _cc_proc_mesh, _controller_controller
403414

404415

@@ -419,7 +430,11 @@ def get_or_spawn_controller(
419430
A Future that resolves to a reference to the actor.
420431
"""
421432
cc = context().actor_instance._controller_controller
422-
if not isinstance(cc, _ControllerController):
433+
if (
434+
cc is not None
435+
and cast(ActorMesh[_ControllerController], cc)._class
436+
is not _ControllerController
437+
):
423438
# This can happen in the client process
424439
cc = _get_controller_controller()[1]
425440
return cc.get_or_spawn.call_one(name, Class, *args, **kwargs)

python/monarch/actor/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
Monarch Actor API - Public interface for actor functionality.
1010
"""
1111

12+
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
13+
from monarch._rust_bindings.monarch_hyperactor.config import configure
1214
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
1315
from monarch._src.actor.actor_mesh import (
1416
Accumulator,
@@ -45,6 +47,9 @@
4547
ProcMesh,
4648
sim_proc_mesh,
4749
)
50+
from monarch._src.actor.v1.proc_mesh import (
51+
get_or_spawn_controller as get_or_spawn_controller_v1,
52+
)
4853

4954
__all__ = [
5055
"Accumulator",
@@ -77,4 +82,7 @@
7782
"Extent",
7883
"run_worker_loop_forever",
7984
"attach_to_workers",
85+
"get_or_spawn_controller_v1",
86+
"configure",
87+
"ChannelTransport",
8088
]

0 commit comments

Comments
 (0)