Skip to content

Commit 5bfacfe

Browse files
samlurye authored and facebook-github-bot committed
How many controllers would a ControllerController control? (#903)
Summary: Pull Request resolved: #903

This diff introduces the `get_or_spawn_controller(name, Class, *args, **kwargs)` API, which creates a singleton actor of type `Class` indexed by `name` if no such actor with `name` already exists; otherwise, it returns a reference to the existing singleton actor with that `name`. The singleton actor is created on the root client.

To facilitate this API and manage all of the singleton controllers, this diff adds the `_ControllerController` singleton, which lives on the root client and is spawned the first time a proc mesh is created. Every time a new actor mesh is created, the mesh receives a reference to the `_ControllerController`, which is added to `MonarchContext` and is accessed inside `get_or_spawn_controller`.

With the new `get_or_spawn_controller` functionality, we no longer need to initialize `_RdmaManager` and `DebugManager` (or any manager) inside `proc_mesh.py`. (Actually, `DebugManager` is no longer needed at all, and `DebugClient` has been renamed to `DebugController`.)

In order to support the `_RdmaManager` pattern, where a controller singleton spawns a manager actor on a user-specified proc mesh, we need a way to pass proc meshes through actor endpoints. This diff implements the exceedingly hacky `ProcMeshRef`, which just contains the proc mesh id, and can't do anything `ProcMesh` can do. However, when passed to an actor endpoint on the root client process, `ProcMeshRef` can be dereferenced into an actual `ProcMesh` by looking up its id in the `_proc_mesh_registry`, which is a global map from proc mesh id to `weakref(ProcMesh)`.

Also, `rdma/__init__.pyi` was moved into `rdma.pyi` to resolve some pesky lint issues related to importing `monarch._rust_bindings.rdma`.

ghstack-source-id: 303907982
Reviewed By: zdevito
Differential Revision: D80374416
fbshipit-source-id: 58373a58454f9986c10f947cde8444858f46acb7
1 parent b1467fd commit 5bfacfe

File tree

10 files changed

+325
-158
lines changed

10 files changed

+325
-158
lines changed

monarch_hyperactor/src/pytokio.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
use std::error::Error;
1010
use std::future::Future;
11-
use std::ops::Deref;
1211
use std::pin::Pin;
1312

1413
use hyperactor::clock::Clock;
@@ -20,6 +19,7 @@ use pyo3::exceptions::PyStopIteration;
2019
use pyo3::exceptions::PyTimeoutError;
2120
use pyo3::exceptions::PyValueError;
2221
use pyo3::prelude::*;
22+
use pyo3::types::PyNone;
2323
use pyo3::types::PyType;
2424
use tokio::sync::Mutex;
2525
use tokio::sync::watch;
@@ -182,8 +182,17 @@ impl PyPythonTask {
182182
}
183183

184184
#[staticmethod]
185-
fn from_coroutine(coro: PyObject) -> PyResult<PyPythonTask> {
186-
PyPythonTask::new(async {
185+
fn from_coroutine(py: Python<'_>, coro: PyObject) -> PyResult<PyPythonTask> {
186+
// MonarchContext.get() used inside a PythonTask should inherit the value of
187+
// MonarchContext from the context in which the PythonTask was constructed.
188+
// We need to do this manually because the value of the contextvar isn't
189+
// maintained inside the tokio runtime.
190+
let monarch_context = py
191+
.import("monarch._src.actor.actor_mesh")?
192+
.getattr("MonarchContext")?
193+
.call_method0("get")?
194+
.unbind();
195+
PyPythonTask::new(async move {
187196
let (coroutine_iterator, none) = Python::with_gil(|py| {
188197
coro.into_bound(py)
189198
.call_method0("__await__")
@@ -196,12 +205,23 @@ impl PyPythonTask {
196205
}
197206
loop {
198207
let action: PyResult<Action> = Python::with_gil(|py| {
208+
// We may be executing in a new thread at this point, so we need to set the value
209+
// of MonarchContext.
210+
let _context = py
211+
.import("monarch._src.actor.actor_mesh")?
212+
.getattr("_context")?;
213+
let old_context = _context.call_method1("get", (PyNone::get(py),))?;
214+
_context.call_method1("set", (monarch_context.clone_ref(py),))?;
215+
199216
let result = match last {
200217
Ok(value) => coroutine_iterator.bind(py).call_method1("send", (value,)),
201218
Err(pyerr) => coroutine_iterator
202219
.bind(py)
203220
.call_method1("throw", (pyerr.into_value(py),)),
204221
};
222+
223+
// Reset MonarchContext so that when this tokio thread yields, it has its original state.
224+
_context.call_method1("set", (old_context,))?;
205225
match result {
206226
Ok(task) => Ok(Action::Wait(
207227
task.extract::<Py<PyPythonTask>>()

python/monarch/_rust_bindings/rdma/__init__.pyi renamed to python/monarch/_rust_bindings/rdma.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, final, Optional
7+
# pyre-strict
8+
from typing import Any, final
89

910
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
1011

@@ -51,7 +52,7 @@ class _RdmaBuffer:
5152
client: Any,
5253
timeout: int,
5354
) -> PythonTask[Any]: ...
54-
def __reduce__(self) -> tuple: ...
55+
def __reduce__(self) -> tuple[Any, ...]: ...
5556
def __repr__(self) -> str: ...
5657
@staticmethod
5758
def new_from_json(json: str) -> _RdmaBuffer: ...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
import inspect
1313
import itertools
1414
import logging
15-
import os
1615
import random
1716
import traceback
18-
from abc import ABC, abstractmethod
1917

2018
from dataclasses import dataclass
2119
from traceback import TracebackException
@@ -98,7 +96,7 @@
9896
from monarch._rust_bindings.monarch_hyperactor.actor import PortProtocol
9997
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ActorMeshProtocol
10098
from monarch._rust_bindings.monarch_hyperactor.mailbox import PortReceiverBase
101-
from monarch._src.actor.proc_mesh import ProcMesh
99+
from monarch._src.actor.proc_mesh import _ControllerController, ProcMesh
102100

103101
CallMethod = PythonMessageKind.CallMethod
104102

@@ -127,14 +125,18 @@ class MonarchContext:
127125
proc_id: str
128126
point: Point
129127
send_queue: Tuple[Optional["Shared[Any]"], int]
128+
controller_controller: Optional["_ControllerController"]
129+
proc_mesh: Optional["ProcMesh"] # actually this is a ProcMeshRef under the hood
130130

131131
@staticmethod
132132
def get() -> "MonarchContext":
133133
c = _context.get(None)
134134
if c is None:
135135
mb = Mailbox.root_client_mailbox()
136136
proc_id = mb.actor_id.proc_id
137-
c = MonarchContext(mb, proc_id, Point(0, singleton_shape), (None, 0))
137+
c = MonarchContext(
138+
mb, proc_id, Point(0, singleton_shape), (None, 0), None, None
139+
)
138140
_context.set(c)
139141
return c
140142

@@ -778,7 +780,12 @@ async def handle(
778780
if ctx is None:
779781
# we reuse ctx across the actor so that send_queue is preserved between calls.
780782
ctx = self._ctx = MonarchContext(
781-
mailbox, mailbox.actor_id.proc_id, Point(rank, shape), (None, 0)
783+
mailbox,
784+
mailbox.actor_id.proc_id,
785+
Point(rank, shape),
786+
(None, 0),
787+
None,
788+
None,
782789
)
783790
ctx.mailbox = mailbox
784791
ctx.proc_id = mailbox.actor_id.proc_id
@@ -791,7 +798,10 @@ async def handle(
791798

792799
match method:
793800
case MethodSpecifier.Init():
794-
Class, *args = args
801+
Class, proc_mesh, controller_controller, *args = args
802+
ctx.controller_controller = controller_controller
803+
ctx.proc_mesh = proc_mesh
804+
_context.set(ctx)
795805
try:
796806
self.instance = Class(*args, **kwargs)
797807
except Exception as e:
@@ -885,7 +895,7 @@ def _maybe_exit_debugger(self, do_continue=True) -> None:
885895
DebugContext.set(DebugContext())
886896

887897
def _post_mortem_debug(self, exc_tb) -> None:
888-
from monarch._src.actor.debugger import DebugManager
898+
from monarch._src.actor.debugger import debug_controller
889899

890900
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
891901
with fake_sync_state():
@@ -894,7 +904,7 @@ def _post_mortem_debug(self, exc_tb) -> None:
894904
ctx.point.rank,
895905
ctx.point.shape.coordinates(ctx.point.rank),
896906
ctx.mailbox.actor_id,
897-
DebugManager.ref().get_debug_client.call_one().get(),
907+
debug_controller(),
898908
)
899909
DebugContext.set(DebugContext(pdb_wrapper))
900910
pdb_wrapper.post_mortem(exc_tb)
@@ -1015,6 +1025,7 @@ def _create(
10151025
mailbox: Mailbox,
10161026
shape: Shape,
10171027
proc_mesh: "ProcMesh",
1028+
controller_controller: Optional["_ControllerController"],
10181029
# args and kwargs are passed to the __init__ method of the user defined
10191030
# python actor object.
10201031
*args: Any,
@@ -1038,7 +1049,7 @@ async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
10381049
None,
10391050
False,
10401051
)
1041-
send(ep, (mesh._class, *args), kwargs)
1052+
send(ep, (mesh._class, proc_mesh, controller_controller, *args), kwargs)
10421053

10431054
return mesh
10441055

python/monarch/_src/actor/debugger.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414
from dataclasses import dataclass
1515
from typing import cast, Dict, Generator, List, Optional, Tuple, Union
1616

17-
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
18-
from monarch._src.actor.actor_mesh import Actor, ActorMesh, DebugContext, MonarchContext
17+
from monarch._src.actor.actor_mesh import Actor, DebugContext, MonarchContext
1918
from monarch._src.actor.endpoint import endpoint
2019
from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper
20+
from monarch._src.actor.proc_mesh import get_or_spawn_controller
2121
from monarch._src.actor.sync_state import fake_sync_state
2222
from tabulate import tabulate
2323

2424

2525
logger = logging.getLogger(__name__)
2626

27-
_DEBUG_MANAGER_ACTOR_NAME = "debug_manager"
28-
2927

3028
async def _debugger_input(prompt=""):
3129
return await asyncio.to_thread(input, prompt)
@@ -424,7 +422,7 @@ class Cast(DebugCommand):
424422
command: str
425423

426424

427-
class DebugClient(Actor):
425+
class DebugController(Actor):
428426
"""
429427
Single actor for both remote debuggers and users to talk to.
430428
@@ -563,26 +561,12 @@ async def debugger_write(
563561
await self.sessions.get(actor_name, rank).debugger_write(write)
564562

565563

566-
class DebugManager(Actor):
567-
@staticmethod
568-
@functools.cache
569-
def ref() -> "DebugManager":
570-
ctx = MonarchContext.get()
571-
return cast(
572-
DebugManager,
573-
ActorMesh.from_actor_id(
574-
DebugManager,
575-
ActorId.from_string(f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]"),
576-
ctx.mailbox,
577-
),
578-
)
579-
580-
def __init__(self, debug_client: DebugClient) -> None:
581-
self._debug_client = debug_client
582-
583-
@endpoint
584-
def get_debug_client(self) -> DebugClient:
585-
return self._debug_client
564+
# Cached so that we don't have to call out to the root client every time,
565+
# which may be on a different host.
566+
@functools.cache
567+
def debug_controller() -> DebugController:
568+
with fake_sync_state():
569+
return get_or_spawn_controller("debug_controller", DebugController).get()
586570

587571

588572
def remote_breakpointhook():
@@ -602,14 +586,12 @@ def remote_breakpointhook():
602586
"exists on both your client and worker processes."
603587
)
604588

605-
with fake_sync_state():
606-
manager = DebugManager.ref().get_debug_client.call_one().get()
607589
ctx = MonarchContext.get()
608590
pdb_wrapper = PdbWrapper(
609591
ctx.point.rank,
610592
ctx.point.shape.coordinates(ctx.point.rank),
611593
ctx.mailbox.actor_id,
612-
manager,
594+
debug_controller(),
613595
)
614596
DebugContext.set(DebugContext(pdb_wrapper))
615597
pdb_wrapper.set_trace(frame)

python/monarch/_src/actor/pdb_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from monarch._src.actor.sync_state import fake_sync_state
2020

2121
if TYPE_CHECKING:
22-
from monarch._src.actor.debugger import DebugClient
22+
from monarch._src.actor.debugger import DebugController
2323

2424

2525
@dataclass
@@ -35,7 +35,7 @@ def __init__(
3535
rank: int,
3636
coords: Dict[str, int],
3737
actor_id: ActorId,
38-
client_ref: "DebugClient",
38+
client_ref: "DebugController",
3939
header: str | None = None,
4040
):
4141
self.rank = rank

0 commit comments

Comments
 (0)